{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 数据处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Install third-party libraries into a persistent directory (run once, then keep commented out)\n",
    "# !mkdir /home/aistudio/external-libraries\n",
    "# !pip install imgaug -t /home/aistudio/external-libraries\n",
    "import sys\n",
    "\n",
    "EXTERNAL_LIBS = '/home/aistudio/external-libraries'\n",
    "# Guard the append so re-running this cell does not add duplicate entries to sys.path\n",
    "if EXTERNAL_LIBS not in sys.path:\n",
    "    sys.path.append(EXTERNAL_LIBS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "image shape: (32, 32, 3)\n",
      "label value: horse\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMgAAADFCAYAAAARxr1AAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAHC1JREFUeJztnXuMXHd1x79nZu689732erO24zixnZg8XGEgFNTybAOqFKhaBH9U+SMCKoFUBP9EVGqp1EpUKqD+UVEFNSKVKIEWUFIaKCGiStNAHiTEeZgkdmJnba937X3NzO687+kfMw47+z17Pdldr3fN+UjW7h7fufd378yZe89bVBWO49jELvcCHGcz4wriOBG4gjhOBK4gjhOBK4jjROAK4jgRuII4TgSuII4TwZoURERuE5GXROSYiNy1XotynM2CrDaSLiJxAC8D+CCAUwCeBPAJVX1xpdfEg4QGqaBDlslkaLtYTEhWqVRIptowjxOG/Pp4LE6yIJkgWT7XY+2RJLXFKskaTd5OY/b1jRtfTf2D20iWyeR4NSHvs1xeNGQLJEun+XpjhY9Ao9kkWRAEJEum0vYO6DB8IDWurSrLwpBlANA0rnmjXiNZPN75/k9PnkepUOQPyjL4E9I9bwdwTFVfBQARuQ/A7QBWVJAgFWDnjXs7ZDfd9BbaLpNJkuyV40dJVq/NmseplPj1OeODv2v3MMnecevvkUxCvuCnnjlOsvOzJZI1svxaAMjz5x5//PE/J9lbbnonycqLvM8jzz9FsueeY9nBA3y9LYUDgNm5Asm2je4k2dV795FMhL8B6k3+UqmDv/iqTVbshcWyucbSLMunz46TrL833/H3333+S+b+lrOWR6wxAEtXcqot60BEPiUiT4nIU806fyM5zmbmkhvpqnq3qh5W1cPxgB9zHGczs5ZHrNMAdi35e2dbtiIahqhXOm+JszPnaLvkju0kGx1h2bmpunmcuXN8i85m2V5ZWJwh2anTL5NsbIRujEinUyRTnSfZ4NCAuca3HtrPxxkdJVnTeJ6ePc/X7OWjL5FsboYfkRYW+HGqt99eY98wPwdWQ76O52fPkyyMZ0kmYtkLbDtVq/yoWq7aj6qLNd5nkOZ1yzL756LGR5u13EGeBLBPRK4RkSSAjwN4YA37c5xNx6rvIKraEJHPAvhvAHEA96jqC+u2MsfZBKzlEQuq+iCAB9dpLY6z6fBIuuNEsKY7yJslnohjaLivQ1avsh87YQTXioU5klWMeAAADA8MsWw4T7KbbmGfPowg5dnJUyTLBmykZ7MchGs07DWm0xyrSSX42DHl2EEixobywQPXkWz3VVeRLNfDwcggawVHgcoiOx1qjSLJ5ubZ2TGzwO+X5XAIhM9lscD7q4Z2iKBoxA8b81Mkyy6Lt1SNwLOF30EcJwJXEMeJwBXEcSJwBXGcCDbUSE+nUrju2j0dsiDJ6SepFMvm56dJ1qzb2bxqJBcOD/WT7PoDe0k2PcsJkM8+8wqvMdhBsr7+XpIVY2ysAnZUOQZedyrBstERjhTnMntI9sJznBXQk+HM2zDJDgcAiNVYnkuwsZxK8vdsocTnVyxwhoOEbCxXipwBsNiwsyZmjIzjsMjv4WK9MzrfWOGzsxy/gzhOBK4gjhOBK4jjROAK4jgRbKyRnk7hwP5rO2S5HEef6w023IoFjo42a3Z0dWqCX79jhKsHBw2jenqanQGJBDsNjII5hIYhGQ/s76Aw5G1j4KyCVJIj6VYFYCPNsljIr80aRvZ8jVPOAaA4z8ZyT9pwLoR8jqkGy3LCH7eZeY7MWxF3GE4NAIgpX8dymdPlK4VOw90qJzb339VWjvNbiiuI40TgCuI4EbiCOE4EriCOE8GavFgicgJAEUATQENVD19ke6SWNWtLpXgJVg+rgzccIFk6sJd/8jinGoyNctOHIMGvtxrMZY30jHiTv1uGhvpI
FvbY3pJUiutBmk32vtSqRhO8uLFuw0OUCtjzk0/yeqbO83EB4PzZsyyrsmercJZ7dZQWjfUY1/HE+DGS7drDdSw92zlVCADiFfZ4FYxGIPW5ztqWZqO7VJP1cPO+V1W5rYXjXAH4I5bjRLBWBVEAPxGRX4rIp6wNlnZWXCjZASnH2ays9RHr3ap6WkS2A3hIRH6tqo8s3UBV7wZwNwCM7R71mdPOlmKtbX9Ot39OicgP0Gpo/chK28cTCfQNdqZ81BucXlGscPrBVbvYcOtL2c0GqsUTJBscYCMvFuMbaCbLxvPAIDd86AlZtnPn1SQLM7YxmO8zajCMdn+lBW6cEDcaSzRqnLJjpchUjQ6FzzzxS3ONP3/61ySrL/B7M33mdX5xjGtWtg9xus/4aTbSc3l+X7btto30nGFsx4x6oHBZR8hupxqs+hFLRHIi0nPhdwB/AOD51e7PcTYja7mDjAD4gYhc2M+/qeqP12VVjrNJWEvr0VcB3LKOa3GcTYe7eR0ngg2tB4kl4sgNDXbIzkycoO3OzXPccWgHt+iPiz36K2M0Ieg3WvxncmyQDxl1I8UFrqvIGEb60DAfI8jbl1gDNi6TaXY61OtsFFtdAWtlNtwLJa6VeOKJX5DsJw/afpXZOTZ2G4aRX2uywRsPeLuRPv4+Tjf5PTh/hiP4e27gJhkA0GvU2+SM0XoVGvV2iY10x/ltwBXEcSJwBXGcCFxBHCeCDTXSIYBkOtOtYxk2LhNGNHtukY1VNYxDAIgbowV6BzgSn+/laO9rpzhV/twkp4hnhQ3gA9ewkT02yvMNAWCxyeeTSPB6YtY45TpHyE+f4pEBD/2EDfJnn+Xo+My0nSOXTnH6ftMYV2AFpStlw7mwwGn1mRjPXS+e59eePn7GXOPYGI9zyKZ4PuJcrHOf1phqC7+DOE4EriCOE4EriONE4AriOBFsbCQ9rgjyncbtjt2DtF3/do5Sq5XWXLaN9FyMU6Mnpzmye+Qod2t87DFOSD53miP7KWXjcnGGL+d7/9AeLbDnwC6SBXHeJ2KcLTB+jlPgf/xfPyfZIz/jNPZa0xgjYJwLANTrvG29wVkFVaM8oVbj2vXZWXYu5AM+v1jduLbG+wcAsxk+djLJRnoq0/mZEqPUwcLvII4TgSuI40TgCuI4EbiCOE4EFzXSReQeAH8EYEpVb2zLBgF8B8AeACcAfExVOQS9jFAbqDQ6N0ulOXocpA0jK8HGbio0irgBLM5wJP34S9xM7MXnx0k2foKN0ECHSDY7x4byo2d/xdsVuOYeAMb2cvp2b5bPJybsiDh69FWSPf5/L5Bs0QiQJ1NsFDdWyEioGvMDq1WeM1iu8BzGWIJfO7c4SbIwxSMotvezAyMd8GcCAAaG2MlTNcZQ9IWd6e4Jo2mgRTd3kG8CuG2Z7C4AD6vqPgAPt/92nCuOiypIu43P8kSf2wHc2/79XgAfWed1Oc6mYLU2yIiqTrR/P4tWAweTpY3jinN8e3aczcyajXRtNRhasX5RVe9W1cOqerinn+0Nx9nMrDaSPikio6o6ISKjADgkbdBo1HFuqrMT+LZhTkNv1NlYLRr914KQI7MAcPQZnjOICjd1G+jlVPT5XjbSk0Y0e7ZiNG9juxTT5+xU8mdffJRkxblTJDO/wYzIdyLOqfZ9vbzuUDl1X41GawBQq7ODoW7Jmvxk8NZ33UiyXaOjJHv28SMkKzU5Cl+Hvcbd+/k9jBkTA0bLne/rI//5U3N/tK+utmIeAHBH+/c7ANy/yv04zqbmogoiIt8G8HMAB0TklIjcCeDLAD4oIq8A+ED7b8e54rjoI5aqfmKF/3r/Oq/FcTYdHkl3nAg2NN292WiiNNtpgMUavIRGg0eHzc5w5Lo4YxtuR37BUfMbr+URbHWjIdz0JDct6+/lZnK5PHdTLysb5PkcR3oBIDHN9dnlEn9fNRt8jrkcHzubZsM9k+GMgmKRzzkmfL0BIJ1mI79mOFD6hrk84a3v
exvJ9h+4hmQDu7jZ3q+feY1khYadqCF5jppXxYjilzqj/U3tbgSb30EcJwJXEMeJwBXEcSJwBXGcCFxBHCeCjfVi1ZuYmez0JhSmOU2hYTRomJpkz1S5YKeaTIyzx6t6/imShVVOISsWjayZkNMzrtrOnq258+zFajT4tQDQ28OvX+znGpGFItdaiPG2NZrs7Yon2OMUGKMhEgl7jERvH3vB4lPcJGH3wb0ku/7wQZJl8vx+Hfr9m0m2fYzTj8oLdqJrWYzmEMZcx3OLndexERrNKwz8DuI4EbiCOE4EriCOE4EriONEsKFGepAMMDq2s0MWj/MSyiVOw0iBDbyzdTbQAKBhGHQnTj7N60kY8+16uYlAIsWGdt0YX1AqcR1KKrXfXOPeEa5PqRsNERp1TgOxjO+E0RwxNL7+8oOcFjKQ5RECAJBO8XGu2suG++9+6ADJ9ozy6ISGsmGsed5fj5XGU7WbX4QBf1Z6E7zPhnTuM2FcQwu/gzhOBK4gjhOBK4jjROAK4jgRrLaz4pcAfBLAhfD2F1X1wYvtK5VO49r9+zr3H2PjOy28rJRRsjD+GnfqA4AnHzLqCco8w8/q+t8/xMZlT54N26mJCZLVa5wB0GzatRbWPnft3E0y6xtMwcauGDUdxQWOwvf0sAE8Mmwb6YFxfXbdzI0Xxq7jcRNiRPaTcX6v40k+iHXNYkn7OlaMcQ75OGdIxJcdO2E4hyxW21kRAL6mqofa/y6qHI6zFVltZ0XH+a1gLTbIZ0XkiIjcIyJcN9lmaWfFwqx3VnS2FqtVkK8DuBbAIQATAL6y0oZLOyv2DnhnRWdrsapIuqq+YR2LyDcA/LCb1zWbTRSLnZHPhGGklSps4KWNsPDcrB1Jr1aM2Xp1Y8ahFUnPGI0KFrgJwNkz3NxBjZl+J4+fNNdYNjoz5rOcij62g9PiFRzZr1Q51f7ka6+QLJviYyyk7Sfo/DB/oSX7uIPjbJmPrWU2qhPGDMbAMNyThndAQnuOYgzGWITQyD6QZe9/d4H01d1B2u1GL/BRADz50nGuALpx834bwHsADIvIKQB/DeA9InIIrabVJwB8+hKu0XEuG6vtrPgvl2AtjrPp8Ei640SwoenuoSrK1U6DN6yyAZxPGnP0jLrwhUU7BbpaZWO5abzeqhefL7DBGoYcmZ2b5/mGFcNYDWPcyRAAEkZkuGGk2sfFqivnt21+jjsPnp/kaH82zt+J9Yp9Hfvi7L3fBc4ACMt8fZqGoZywMgCMqHnGSFdvjaFhmgHLJcay2LKMDVFPd3ecNeMK4jgRuII4TgSuII4TwYYa6QAgyyKnWaPFfjpgIy2jrMuJuDGLEECzueJM0Q7qxmiB6WmOkI/tYsP0lrdxczOJs+F3Zpwb3gHAyVNPkiyTYaO4boyCSKf4mtWNKH61xpkGxTmOpDcCY7gigL4UR6nDODs2qg2+3g3DSA/rfL1jhpFeEz6X8qK9Rg34mieTnDWRS3duZzleLPwO4jgRuII4TgSuII4TgSuI40SwoUa6GpH0uhqpyYb9FMTZcIdRzw4AYjSZswKxahiSuTwbwB/90w+QbO8NbKTHU7zGI0+9YK7xpz96lGTnjbmFzTobxc2YUfsOI3KdYdlChQ337UOcUg8AB27m5na9/Wy4F6rcRM+aAVirsaFdN0obrBT4utHxHwDCqpXaztkL1WWN4pordN1fjt9BHCcCVxDHicAVxHEicAVxnAi6qSjcBeBfAYygVUF4t6r+o4gMAvgOgD1oVRV+TFXtae9tFECITmu50WTjS40UbyvwWavahlutyoafsUuI8fWwbz+PE9t/kLuXZ4b50tWUDb/D736Hucb9+99CsvEzZ0h25ixH9qF8bA35BH90/0Mkm3iZR8wNjnCaPQBs38kN4RZKPN4sbPJ5xwI2nkWMN9H4BBaMLvdN4xgAkBI26GOGR6Za7vysrGckvQHgC6p6EMCtAD4jIgcB3AXgYVXdB+Dh9t+Oc0XRTeO4
CVV9uv17EcBRAGMAbgdwb3uzewF85FIt0nEuF2/KBhGRPQB+B8DjAEZU9ULJ2lm0HsGs17zROG5hnivuHGcz07WCiEgewPcAfE5VO6JN2qqHNB/qljaOy/UZPYwcZxPTlYKISICWcnxLVb/fFk9e6I/V/mkMGHecrU03XixBq83PUVX96pL/egDAHQC+3P55fxf7QrC8a55RQ9E00k+MjAvMTdu9fms1ozmA0azAcm3tNobYZ1PcYXBhntNCakZjiOQKLfwG8lz70bOfuxaODA2SzPLAGGUjeOx/HiPZuOE17OmxuxbmjIYIzRKncRiTF8y0m5rR/TFjdVFM8zWrGesGgIRx4mJ8fpZ7T1d44OH9d7HNuwD8GYDnRORXbdkX0VKM74rInQBOAvhYV0d0nC1EN43jHsXKnUzfv77LcZzNhUfSHScCVxDHiWBjmzaoUlqC1cp/ocx6qyHXaVTmbUMrtOYCxvgpMWs0Jbj6Kp7BF6vwGmWBj50MOe0hZRW3AAgCdiRk49wcIhXjc6k0uKaj1GDjOWk4NpJxbtpwwzVj5hr3DhqpJkleY8nocDlvzGusFK2RCMZ7ZaSkiJUXBHPSAerG3MJmo3ON4QqdGpfjdxDHicAVxHEicAVxnAhcQRwngg1v2tCsdBqTGhhjCWrGDL4iyyZP210L1Yg0G+US6O/rI9mQIZs7a3RwDPnSpRNs9DcadkfAWNwwqo13I6izAVwpcwaBGk0S1Oh4GBhNKQa28zkDQDrNC4oLG/mG/wN9yhHyoQbLJie43qVhOBwq4QpNGwzj3TL8q6Vl74OVemDgdxDHicAVxHEicAVxnAhcQRwngg010mMKpJbZ2rGAo89loxnDzAQb5Oen7BIUK7PS6heQy7JRvVBkA/i1l06QLBnj1w71cYfChDEvDwDCpJFKXj1PMjWM/KqVX57hlPy48NubyhkNH8RuiLBY4JR+VTbSYaXQxzMkS6d5jWEPOwikOMfHLdtGdWG58Q0gnTQyMYqdn7NY02cUOs6acQVxnAhcQRwnAlcQx4lgLZ0VvwTgkwAuWM9fVNUHo/YVgyAbdkZTa2WOmqLE6cqL57ijX21hhSi1YaaLUX8exI3uiGXepxjp94UK11c3Cmzs9vWyYQoAkjVq8RdneJ81Pk49yfXe8STXs4sR7e/NGPMf67YBXJ/itHoFG+lqpOTPhvwe1ow5ijEjWp8O2cCPG+cMAIkGX8d4jR0g8WrnumPanZHejRfrQmfFp0WkB8AvReRCT8uvqeo/dHUkx9mCdFOTPgFgov17UUQudFZ0nCuetXRWBIDPisgREblHRLiPDTo7K5YK3lnR2VqspbPi1wFcC+AQWneYr1ivW9pZMd/rnRWdrUVXkXSrs6KqTi75/28A+OFFd9QEYsuCs0GCI+m94LrnoMLR1WqJU8EBIB5jvQ+SbFymAjb80sZ2PSlOES812blQmOdZfcV5eyJEIsavv26Uo8qZPF+LQoGj/fNnOKtgbpaN7IEcG8ADwjIAyC4YWQ7GTMHQmIVYXp4yAaBujJRU4/2PG3XziRWy061ov3FpkdTO91pW7GS1bF8X22ClzooX2o62+SiA57s6ouNsIdbSWfETInIILdfvCQCfviQrdJzLyFo6K0bGPBznSsAj6Y4TwcamuyOGnHYanaEx1y+RZG/X7iFeal/uFfM4QYIN46TRRTxhDKxfWGCjr1nj6HqQYMO2Z4Cj2aFRXw0AiwuGgyHPRnoixw6CsMIW6+lxNtKnZ9mxMTy8m2SBER0HgGST5ZUqn0+xzBkAlV6+tknj+sTT/L7US+wIaNTt62hlJKiRYLE84N5d2zi/gzhOJK4gjhOBK4jjROAK4jgRbLiRntJOAzxmdD/XGhtuuYANvH3X7jOP0zQisePjp0lWKnDk+9grx3h/Rl14Is7G8+g2zuHsMYxsANi2nUer1ZJs+JeEjV1N83YNI50/leXtakaWQV3sVHKNG6n6Rnc7aXBkv17kKH4s
4NcGxrp7jcyFecNRAgDxPl57o2o0I5xflmpvHNfC7yCOE4EriONE4AriOBG4gjhOBK4gjhPBxs4oBKDxi8f800b6wZ69O0nWv2ObeYx3LrLH49H/fYxkrx1nj1WtZgyhr3KzgWrIqQ+v13i7dMZO45DcdSTLp4zzMboEJvvZO9U7xJ6fwSE+9sIie5ymZ20P0eCeHSQLA15PptpLsqDGXqJKhWVVNeYJZvlDsWCMdwCAsuGMSuc5VWm5E1SsmQ0GfgdxnAhcQRwnAlcQx4mgm5LbtIg8ISLPisgLIvI3bfk1IvK4iBwTke+IrBCOdZwtTDdGehXA+1S11G7e8KiI/AjA59FqHHefiPwzgDvR6nSyIqqKeqPT2LLqNNIpNi6DGG+X7eGGBgCwQ1nvt/fdRrLx1zn9pDfPKS1Tr79Gsvl5rrUI8mw8V0N7tEDd8E5UjHmEGuO3qFTiFJmEkbJz4/XXkKxvgJ0dg/08tgEApit8nB1Xs+FeOM0Oi8lZHuVQqLOhXTMcMlpnAzrM2N/lqQR/VqZnuT7l9MlTHX9XalxzYnHRO4i2uFBFFLT/KYD3AfiPtvxeAB/p6oiOs4XoygYRkXi7YcMUgIcAHAcwp78ZrXoKK3RbXNo4rlBiF6PjbGa6UhBVbarqIQA7AbwdwPXdHmBp47jevN3I2XE2K2/Ki6WqcwB+BuCdAPpF3pjxtRMAP9A7zhanm/EH2wDUVXVORDIAPgjg79FSlD8BcB+AOwDcf9GjqSBe7zxkLMFLiBk1CzWjaD9pl1pAjOHyO4c4Sj02OEKywjyPWehL8P5mZqdJVrIaCxj1EwAQGl9NWmFDu2p0MkwbcwL7R7gZw/49fKOvG0a/Vm2DNd/LDot8H1/06gzXrFTBEfIzU5MkKxmG++AY18oEeTbmASAEX/OUUQ/UM9DZOjpuNOyw6MaLNQrgXhGJo3XH+a6q/lBEXgRwn4j8LYBn0Oq+6DhXFN00jjuCVkf35fJX0bJHHOeKxSPpjhOBK4jjRCCq3faYW4eDiZwDcBLAMAAOtW5N/Fw2Jxc7l6tV1a6XWMKGKsgbBxV5SlUPb/iBLwF+LpuT9ToXf8RynAhcQRwngsulIHdfpuNeCvxcNifrci6XxQZxnK2CP2I5TgSuII4TwYYriIjcJiIvtUt179ro468FEblHRKZE5PklskEReUhEXmn/HIjax2ZBRHaJyM9E5MV2KfVftOVb7nwuZVn4hipIO+HxnwB8CMBBtCblHtzINayRbwJYXrt7F4CHVXUfgIfbf28FGgC+oKoHAdwK4DPt92Irns+FsvBbABwCcJuI3IpW1vnXVPU6ALNolYW/KTb6DvJ2AMdU9VVVraGVKn/7Bq9h1ajqIwCWFzzfjlbJMbCFSo9VdUJVn27/XgRwFK2q0C13PpeyLHyjFWQMwPiSv1cs1d1CjKjqRPv3swC4yGSTIyJ70MrYfhxb9HzWUhYehRvp64i2fOZbym8uInkA3wPwOVXtmHqzlc5nLWXhUWy0gpwGsGvJ31dCqe6kiIwCQPsnz2PepLTbOH0PwLdU9ftt8ZY9H2D9y8I3WkGeBLCv7V1IAvg4gAc2eA3rzQNolRwD3ZYebwJERNCqAj2qql9d8l9b7nxEZJuI9Ld/v1AWfhS/KQsHVnsuqrqh/wB8GMDLaD0j/uVGH3+Na/82gAkAdbSeae8EMISWt+cVAD8FMHi519nlubwbrcenIwB+1f734a14PgBuRqvs+wiA5wH8VVu+F8ATAI4B+HcAqTe7b081cZwI3Eh3nAhcQRwnAlcQx4nAFcRxInAFcZwIXEEcJwJXEMeJ4P8BDOlRG0/6AGEAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 216x216 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import os\n",
    "import paddle\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "import imgaug as ia\n",
    "import imgaug.augmenters as iaa\n",
    "\n",
    "# Read one batch of CIFAR-10 training samples\n",
    "reader = paddle.batch(\n",
    "    paddle.dataset.cifar.train10(),\n",
    "    batch_size=8) # dataset reader\n",
    "data = next(reader()) # first batch\n",
    "index = 4 # index of the sample to show within the batch\n",
    "\n",
    "# Read the image\n",
    "image = np.array([x[0] for x in data]).astype(np.float32) # pixel data as float32, in [0, 1]\n",
    "# Map [0, 1] back to [0, 255]. Round instead of truncating: float error such as 56.9999\n",
    "# would otherwise become 56 after astype(np.uint8).\n",
    "image = np.clip(np.round(image * 255), 0, 255)\n",
    "image = image[index].reshape((3, 32, 32)).transpose((1, 2, 0)).astype(np.uint8) # CHW -> HWC, uint8\n",
    "print('image shape:', image.shape)\n",
    "\n",
    "# Optional augmentation preview; set to True to display a randomly augmented version\n",
    "SHOW_AUGMENT = False\n",
    "if SHOW_AUGMENT:\n",
    "    sometimes = lambda aug: iaa.Sometimes(0.5, aug) # apply the wrapped augmenter with probability 0.5\n",
    "    seq = iaa.Sequential([\n",
    "        sometimes(iaa.CropAndPad(px=(-4, 4))),      # random crop / pad by up to 4 pixels\n",
    "        iaa.Fliplr(0.5)])                           # random horizontal flip with probability 0.5\n",
    "    image = seq(image=image)\n",
    "\n",
    "# Read the label\n",
    "label = np.array([x[1] for x in data]).astype(np.int64) # labels as int64\n",
    "vlist = [\"airplane\", \"automobile\", \"bird\", \"cat\", \"deer\", \"dog\", \"frog\", \"horse\", \"ship\", \"truck\"] # CIFAR-10 class names\n",
    "print('label value:', vlist[label[index]])\n",
    "\n",
    "# Show and save the image\n",
    "image = Image.fromarray(image)            # convert to a PIL image\n",
    "os.makedirs('./work/out', exist_ok=True)  # save() fails if the directory does not exist yet\n",
    "image.save('./work/out/img.png')          # save the displayed image\n",
    "plt.figure(figsize=(3, 3))                # display size\n",
    "plt.imshow(image)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_data: image shape (128, 3, 32, 32), label shape:(128, 1)\n",
      "valid_data: image shape (128, 3, 32, 32), label shape:(128, 1)\n"
     ]
    }
   ],
   "source": [
    "import paddle\n",
    "import numpy as np\n",
    "import imgaug as ia\n",
    "import imgaug.augmenters as iaa\n",
    "\n",
    "# Per-channel CIFAR-10 mean / std, shaped to broadcast over BHWC batches\n",
    "CIFAR_MEAN = np.array([0.4914, 0.4822, 0.4465]).reshape((1, 1, 1, -1))\n",
    "CIFAR_STDV = np.array([0.2471, 0.2435, 0.2616]).reshape((1, 1, 1, -1))\n",
    "\n",
    "# Training augmentation pipeline, built once instead of on every call\n",
    "_sometimes = lambda aug: iaa.Sometimes(0.5, aug) # apply the wrapped augmenter with probability 0.5\n",
    "TRAIN_SEQ = iaa.Sequential([\n",
    "    _sometimes(iaa.CropAndPad(px=(-4, 4))),      # random crop / pad by up to 4 pixels\n",
    "    iaa.Fliplr(0.5)])                            # random horizontal flip with probability 0.5\n",
    "\n",
    "def _to_uint8_bhwc(images):\n",
    "    \"\"\"Convert a float BCHW batch in [0, 1] into a uint8 BHWC batch in [0, 255].\n",
    "\n",
    "    Pixels are rounded, not truncated, so float error (e.g. 56.9999) does not\n",
    "    shift them down one level when cast to uint8.\n",
    "    \"\"\"\n",
    "    images = np.clip(np.round(images * 255), 0, 255)\n",
    "    return images.transpose((0, 2, 3, 1)).astype(np.uint8)\n",
    "\n",
    "def _normalize_bchw(images):\n",
    "    \"\"\"Standardize a uint8 BHWC batch with the CIFAR statistics; return float32 BCHW.\"\"\"\n",
    "    images = (images / 255.0 - CIFAR_MEAN) / CIFAR_STDV\n",
    "    return images.transpose((0, 3, 1, 2)).astype(np.float32)\n",
    "\n",
    "# Training data augmentation\n",
    "def train_augment(images):\n",
    "    \"\"\"Randomly crop/pad and flip a float BCHW batch in [0, 1], then standardize it.\n",
    "\n",
    "    Returns a float32 BCHW batch normalized with the CIFAR channel mean / std.\n",
    "    \"\"\"\n",
    "    images = _to_uint8_bhwc(images)\n",
    "    images = TRAIN_SEQ(images=images)\n",
    "    return _normalize_bchw(images)\n",
    "\n",
    "# Validation preprocessing: same conversion and normalization, no random augmentation\n",
    "def valid_augment(images):\n",
    "    \"\"\"Standardize a float BCHW batch in [0, 1]; returns float32 BCHW.\"\"\"\n",
    "    return _normalize_bchw(_to_uint8_bhwc(images))\n",
    "\n",
    "# Read one training batch\n",
    "train_reader = paddle.batch(\n",
    "    paddle.reader.shuffle(paddle.dataset.cifar.train10(), buf_size=50000),\n",
    "    batch_size=128) # shuffled training reader\n",
    "train_data = next(train_reader())\n",
    "\n",
    "train_image = np.array([x[0] for x in train_data]).reshape((-1, 3, 32, 32)).astype(np.float32) # training images\n",
    "train_image = train_augment(train_image)                                                       # augment + normalize\n",
    "train_label = np.array([x[1] for x in train_data]).reshape((-1, 1)).astype(np.int64)           # training labels\n",
    "print('train_data: image shape {}, label shape:{}'.format(train_image.shape, train_label.shape))\n",
    "\n",
    "# Read one validation batch\n",
    "valid_reader = paddle.batch(\n",
    "    paddle.dataset.cifar.test10(),\n",
    "    batch_size=128) # validation reader\n",
    "valid_data = next(valid_reader())\n",
    "\n",
    "valid_image = np.array([x[0] for x in valid_data]).reshape((-1, 3, 32, 32)).astype(np.float32) # validation images\n",
    "valid_image = valid_augment(valid_image)                                                       # normalize\n",
    "valid_label = np.array([x[1] for x in valid_data]).reshape((-1, 1)).astype(np.int64)           # validation labels\n",
    "print('valid_data: image shape {}, label shape:{}'.format(valid_image.shape, valid_label.shape))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 模型设计"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import paddle.fluid as fluid\n",
    "from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear, BatchNorm\n",
    "import math\n",
    "\n",
    "# Group architecture, one tuple per group: (in_dim, out_dim, stride, basics, queues)\n",
    "#   basics - number of SSRBasic units in the group's SSRBlock\n",
    "#   queues - number of ConvUnit stages in each unit's convolution path\n",
    "group_arch = [(3, 256, 1, 2, 3), (256, 512, 2, 2, 3), (512, 1024, 2, 2, 3)]\n",
    "group_dim  = group_arch[-1][1] # channels produced by the last group (1024); derived so it cannot drift\n",
    "class_dim  = 10                # number of CIFAR-10 classes\n",
    "\n",
    "# Convolution unit: bias-free convolution followed by batch norm (with optional activation)\n",
    "class ConvUnit(fluid.dygraph.Layer):\n",
    "    def __init__(self, in_dim, out_dim, filter_size=3, stride=1, act=None):\n",
    "        \"\"\"\n",
    "        Build the unit. Output size: H/W = (H/W + 2*P - F)/S + 1.\n",
    "\n",
    "        Args:\n",
    "            in_dim      - number of input channels\n",
    "            out_dim     - number of output channels\n",
    "            filter_size - convolution kernel size\n",
    "            stride      - convolution stride\n",
    "            act         - activation applied by the batch-norm layer (None = linear)\n",
    "        \"\"\"\n",
    "        super(ConvUnit, self).__init__()\n",
    "\n",
    "        pad = (filter_size - 1) // 2 # keeps H/W unchanged for odd kernels at stride 1\n",
    "        self.conv = Conv2D(\n",
    "            num_channels=in_dim,\n",
    "            num_filters=out_dim,\n",
    "            filter_size=filter_size,\n",
    "            stride=stride,\n",
    "            padding=pad,\n",
    "            param_attr=fluid.initializer.MSRA(uniform=False), # MSRA (He) normal initialization\n",
    "            bias_attr=False,                                  # no bias: batch norm supplies the shift\n",
    "            act=None)\n",
    "\n",
    "        self.norm = BatchNorm(\n",
    "            num_channels=out_dim,\n",
    "            param_attr=fluid.initializer.Constant(1.0), # scale initialized to 1\n",
    "            bias_attr=fluid.initializer.Constant(0.0),  # shift initialized to 0\n",
    "            act=act)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Convolve x, then batch-normalize (and activate) the result.\"\"\"\n",
    "        return self.norm(self.conv(x))\n",
    "\n",
    "# Projection unit: average pooling, then 1x1 convolution and batch norm\n",
    "class ProjUnit(fluid.dygraph.Layer):\n",
    "    def __init__(self, in_dim, out_dim, filter_size=1, stride=1, act=None):\n",
    "        \"\"\"\n",
    "        Build the unit. Output size: H/W = (H/W + 2*P - F)/S + 1.\n",
    "\n",
    "        Args:\n",
    "            in_dim      - number of input channels\n",
    "            out_dim     - number of output channels\n",
    "            filter_size - average-pooling window (the convolution itself is always 1x1)\n",
    "            stride      - average-pooling stride\n",
    "            act         - activation applied by the batch-norm layer\n",
    "        \"\"\"\n",
    "        super(ProjUnit, self).__init__()\n",
    "\n",
    "        # Downsampling is done by the pooling layer, so the 1x1 conv below keeps stride 1\n",
    "        self.pool = Pool2D(\n",
    "            pool_size=filter_size,\n",
    "            pool_stride=stride,\n",
    "            pool_padding=0,\n",
    "            pool_type='avg')\n",
    "\n",
    "        self.conv = Conv2D(\n",
    "            num_channels=in_dim,\n",
    "            num_filters=out_dim,\n",
    "            filter_size=1,\n",
    "            stride=1,\n",
    "            padding=0,\n",
    "            param_attr=fluid.initializer.MSRA(uniform=False), # MSRA (He) normal initialization\n",
    "            bias_attr=False,                                  # no bias: batch norm supplies the shift\n",
    "            act=None)\n",
    "\n",
    "        self.norm = BatchNorm(\n",
    "            num_channels=out_dim,\n",
    "            param_attr=fluid.initializer.Constant(1.0), # scale initialized to 1\n",
    "            bias_attr=fluid.initializer.Constant(0.0),  # shift initialized to 0\n",
    "            act=act)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Average-pool x, project it with the 1x1 convolution, then batch-normalize.\"\"\"\n",
    "        pooled = self.pool(x)\n",
    "        projected = self.conv(pooled)\n",
    "        return self.norm(projected)\n",
    "\n",
    "# Queue structure: a chain of ConvUnit stages; each stage except the last keeps part of\n",
    "# its output and forwards the remaining channels to the next, narrower stage\n",
    "class SSRQueue(fluid.dygraph.Layer):\n",
    "    def __init__(self, in_dim, out_dim, stride=1, queues=2, act=None):\n",
    "        \"\"\"\n",
    "        Build the queue. Output size: H/W = (H/W + 2*P - F)/S + 1.\n",
    "\n",
    "        Args:\n",
    "            in_dim  - number of input channels\n",
    "            out_dim - number of output channels after concatenating all stages\n",
    "            stride  - stride of the first stage: 1 keeps the size, 2 downsamples\n",
    "            queues  - number of stages; stage i produces out_dim halved i times (integer division)\n",
    "            act     - activation used by every stage\n",
    "        \"\"\"\n",
    "        super(SSRQueue, self).__init__()\n",
    "\n",
    "        self.queues = queues\n",
    "        self.split_list = [] # split_list[i]: channels stage i forwards to stage i+1\n",
    "        self.queue_list = []\n",
    "\n",
    "        stage_dim = out_dim\n",
    "        for idx in range(queues):\n",
    "            first = (idx == 0)\n",
    "            unit = ConvUnit(\n",
    "                in_dim=(in_dim if first else stage_dim), # later stages consume the forwarded slice\n",
    "                out_dim=stage_dim,\n",
    "                filter_size=3,\n",
    "                stride=(stride if first else 1),         # only the first stage may downsample\n",
    "                act=act)\n",
    "            self.queue_list.append(self.add_sublayer('queue_' + str(idx), unit))\n",
    "\n",
    "            if idx < queues - 1:\n",
    "                stage_dim //= 2\n",
    "                self.split_list.append(stage_dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        Run the stages in order. After every stage but the last, the output is split on the\n",
    "        channel axis: the leading channels are kept and the trailing split_list[i] channels\n",
    "        feed the next stage. The kept parts and the last stage's output are concatenated.\n",
    "        \"\"\"\n",
    "        last = self.queues - 1\n",
    "        outputs = []\n",
    "        for idx, unit in enumerate(self.queue_list):\n",
    "            x = unit(x)\n",
    "            if idx < last:\n",
    "                kept, x = fluid.layers.split(input=x, num_or_sections=[-1, self.split_list[idx]], dim=1)\n",
    "                outputs.append(kept)\n",
    "            else:\n",
    "                outputs.append(x)\n",
    "        return fluid.layers.concat(input=outputs, axis=1) # concatenate along the channel axis\n",
    "\n",
    "# Basic residual structure: shortcut path (identity or projection) + convolution path\n",
    "class SSRBasic(fluid.dygraph.Layer):\n",
    "    def __init__(self, in_dim, out_dim, stride=1, queues=1, is_pass=True):\n",
    "        \"\"\"\n",
    "        Build the unit. Output size: H/W = (H/W + 2*P - F)/S + 1.\n",
    "\n",
    "        Args:\n",
    "            in_dim  - number of input channels\n",
    "            out_dim - number of output channels\n",
    "            stride  - stride of both paths\n",
    "            queues  - 1 uses a single ConvUnit, otherwise an SSRQueue with this many stages\n",
    "            is_pass - True: identity shortcut; False: shortcut through ProjUnit\n",
    "        \"\"\"\n",
    "        super(SSRBasic, self).__init__()\n",
    "\n",
    "        self.is_pass = is_pass\n",
    "\n",
    "        # NOTE(review): the projection is built even when is_pass is True; its parameters are\n",
    "        # then unused by forward() but still show up in parameters() / state_dict().\n",
    "        self.proj = ProjUnit(in_dim=in_dim, out_dim=out_dim, filter_size=stride, stride=stride, act=None)\n",
    "\n",
    "        if queues == 1:\n",
    "            self.conv = ConvUnit(in_dim=in_dim, out_dim=out_dim, filter_size=3, stride=stride, act='relu')\n",
    "        else:\n",
    "            self.conv = SSRQueue(in_dim=in_dim, out_dim=out_dim, stride=stride, queues=queues, act='relu')\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        Add the shortcut path to the convolution path.\n",
    "\n",
    "        The identity shortcut assumes conv(x) has the same shape as x\n",
    "        (in_dim == out_dim and stride == 1).\n",
    "\n",
    "        Returns:\n",
    "            (x, y) - the same summed feature twice, so callers can collect per-unit features\n",
    "        \"\"\"\n",
    "        shortcut = x if self.is_pass else self.proj(x)\n",
    "        residual = self.conv(x)\n",
    "        summed = fluid.layers.elementwise_add(x=shortcut, y=residual, act=None)\n",
    "        return summed, summed\n",
    "\n",
    "# Block structure: a sequence of SSRBasic units; only the first one changes channels/stride\n",
    "class SSRBlock(fluid.dygraph.Layer):\n",
    "    def __init__(self, in_dim, out_dim, stride=1, basics=1, queues=1):\n",
    "        \"\"\"\n",
    "        Build the block. Output size: H/W = (H/W + 2*P - F)/S + 1.\n",
    "\n",
    "        Args:\n",
    "            in_dim  - number of input channels\n",
    "            out_dim - number of output channels\n",
    "            stride  - stride of the first unit (later units use 1)\n",
    "            basics  - number of SSRBasic units\n",
    "            queues  - queue length passed to every unit\n",
    "        \"\"\"\n",
    "        super(SSRBlock, self).__init__()\n",
    "\n",
    "        self.block_list = []\n",
    "        for idx in range(basics):\n",
    "            head = (idx == 0)\n",
    "            unit = SSRBasic(\n",
    "                in_dim=(in_dim if head else out_dim),\n",
    "                out_dim=out_dim,\n",
    "                stride=(stride if head else 1),\n",
    "                queues=queues,\n",
    "                is_pass=(not head)) # the first unit projects its shortcut, later ones pass it through\n",
    "            self.block_list.append(self.add_sublayer('block_' + str(idx), unit))\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        Run the units in order.\n",
    "\n",
    "        Returns:\n",
    "            x      - output of the last unit\n",
    "            y_list - second output of every unit, in order\n",
    "        \"\"\"\n",
    "        y_list = []\n",
    "        for unit in self.block_list:\n",
    "            x, y = unit(x)\n",
    "            y_list.append(y)\n",
    "        return x, y_list\n",
    "\n",
    "# Group structure: one SSRBlock per entry of the module-level group_arch\n",
    "class SSRGroup(fluid.dygraph.Layer):\n",
    "    def __init__(self):\n",
    "        \"\"\"Build one SSRBlock per (in_dim, out_dim, stride, basics, queues) tuple in group_arch.\"\"\"\n",
    "        super(SSRGroup, self).__init__()\n",
    "\n",
    "        self.group_list = []\n",
    "        for idx, arch in enumerate(group_arch):\n",
    "            in_dim, out_dim, stride, basics, queues = arch[:5]\n",
    "            block = SSRBlock(in_dim=in_dim, out_dim=out_dim, stride=stride, basics=basics, queues=queues)\n",
    "            self.group_list.append(self.add_sublayer('group_' + str(idx), block))\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        Run the blocks in order.\n",
    "\n",
    "        Returns:\n",
    "            x      - output of the last block\n",
    "            y_list - second output of every block, in order\n",
    "        \"\"\"\n",
    "        y_list = []\n",
    "        for block in self.group_list:\n",
    "            x, block_outputs = block(x)\n",
    "            y_list.append(block_outputs)\n",
    "        return x, y_list\n",
    "\n",
    "# Classification network: SSR backbone -> global average pooling -> softmax classifier\n",
    "class SSRNet(fluid.dygraph.Layer):\n",
    "    def __init__(self):\n",
    "        \"\"\"Build the backbone, the global average pooling layer and the classifier head.\"\"\"\n",
    "        super(SSRNet, self).__init__()\n",
    "\n",
    "        self.backbone = SSRGroup()                               # output: N*C*H*W\n",
    "        self.pool = Pool2D(global_pooling=True, pool_type='avg') # output: N*C*1*1\n",
    "\n",
    "        bound = 1.0 / math.sqrt(group_dim) # uniform init range [-1/sqrt(fan_in), 1/sqrt(fan_in)]\n",
    "        self.fc = Linear(                  # output: N*class_dim\n",
    "            input_dim=group_dim,\n",
    "            output_dim=class_dim,\n",
    "            param_attr=fluid.initializer.Uniform(-bound, bound),\n",
    "            bias_attr=fluid.initializer.Constant(0.0),\n",
    "            act='softmax')                 # the head outputs probabilities, not logits\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        Classify a batch of images.\n",
    "\n",
    "        Args:\n",
    "            x - input image batch (N*C*H*W)\n",
    "        Returns:\n",
    "            softmax class probabilities, N*class_dim\n",
    "        \"\"\"\n",
    "        features, _ = self.backbone(x) # per-block feature lists are not needed here\n",
    "        pooled = self.pool(features)\n",
    "        flat = fluid.layers.reshape(pooled, [pooled.shape[0], -1])\n",
    "        return self.fc(flat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total param: 28114954\n",
      "infer shape: [1, 10]\n"
     ]
    }
   ],
   "source": [
    "import paddle.fluid as fluid\n",
    "from paddle.fluid.dygraph.base import to_variable\n",
    "import numpy as np\n",
    "\n",
    "with fluid.dygraph.guard():\n",
    "    # Random input with the shape of one CIFAR image batch\n",
    "    x = np.random.randn(1, 3, 32, 32).astype(np.float32)\n",
    "    x = to_variable(x)\n",
    "\n",
    "    # Build the network and run one forward pass\n",
    "    backbone = SSRNet()\n",
    "    infer = backbone(x)\n",
    "\n",
    "    # Count every registered parameter\n",
    "    parameters = sum(int(np.prod(p.shape)) for p in backbone.parameters())\n",
    "\n",
    "    print('total param:', parameters)\n",
    "    print('infer shape:', infer.shape)\n",
    "    assert list(infer.shape) == [1, class_dim], 'unexpected output shape'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAD8CAYAAACYebj1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzt3Xl8VNX5+PHPM5kkk5VACGtAQLGyihBRqiLUDdz42tIqdbeW1qUurW2x37auVWp//dalVkWLu1Drroi4ixtIkB0EQgiSgJAFsi+znN8f52YyCVkmISHJ8Lxfr3nlzrnbuXPhPvcs91wxxqCUUkoBuDo7A0oppboODQpKKaWCNCgopZQK0qCglFIqSIOCUkqpIA0KSimlgjQoKKWUCtKgoJRSKkiDglJKqSB3Z2egMb179zZDhgzp7GwopVS3sXLlygJjTNrBbqdLBoUhQ4aQmZnZ2dlQSqluQ0R2tMd2tPpIKaVUkAYFpZRSQRoUlFJKBbXYpiAig4BngL6AAeYZYx5osIwADwBnAxXAFcaYr515lwN/dBa92xjzdFsy6vV6yc3Npaqqqi2rqxAej4f09HSio6M7OytKqS4mnIZmH/AbY8zXIpIErBSR94wxG0OWmQ4Mdz4nAI8AJ4hIL+A2IAMbUFaKyBvGmH2tzWhubi5JSUkMGTIEG4NUWxhjKCwsJDc3l6FDh3Z2dpRSXUyL1UfGmN21d/3GmFJgEzCwwWIzgGeMtQxIEZH+wFnAe8aYIicQvAdMa0tGq6qqSE1N1YBwkESE1NRULXEppRrVqjYFERkCHAcsbzBrILAz5Huuk9ZUeptoQGgf+jsqpZoSdlAQkUTgZeAmY0xJe2dERGaLSKaIZObn57dpG3tKqiit8rZzzpRS6vARVlAQkWhsQHjeGPNKI4vkAYNCvqc7aU2lH8AYM88Yk2GMyUhLa9tDefml1ZRV+9q0bnMKCwsZN24c48aNo1+/fgwcODD4vaamJqxtXHnllWzevDnsfT7xxBPcdNNNbc2yUkq1STi9jwT4N7DJGPN/TSz2BnC9iCzENjQXG2N2i8gS4B4R6eksdyZwazvku0nGtP82U1NTWb16NQC33347iYmJ3HLLLQ32azDG4HI1HmeffPLJ9s+YUkq1s3BKCicBlwI/EJHVzudsEfmliPzSWeZtIBvIAh4HrgUwxhQBdwErnM+dTlqHONQ15VlZWYwcOZKLL76YUaNGsXv3bmbPnk1GRgajRo3izjvvDC578skns3r1anw+HykpKcyZM4djjz2WSZMmsXfv3mb3s337dqZOncrYsWM544wzyM3NBWDhwoWMHj2aY489lqlTpwKwbt06jj/+eMaNG8fYsWPJzs7uuB9AKRVxWiwpGGM+o4XrrTHGANc1MW8+ML9NuWvCHW9uYOOuA5s1Kmp8uF0uYtytfyZv5IBkbjtvVKvX++abb3jmmWfIyMgAYO7cufTq1Qufz8fUqVOZOXMmI0eOrLdOcXExp556KnPnzuXXv/418+fPZ86cOU3u49prr+Xqq6/m4osvZt68edx000289NJL3HHHHXz88cf07duX/fv3A/Cvf/2LW265hQsvvJDq6mpMRxSdlFIRS59oPkhHHnlkMCAALFiwgPHjxzN+/Hg2bdrExo0bD1gnLi6O6dOnAzBhwgRycnKa3cfy5cu56KKLALjsssv49NNPATjppJO47LLLeOKJJwgEAgB8//vf5+677+a+++5j586deDye9jhMpdRhokuOktqSpu7oN+4qoUecm4E94w9ZXhISEoLTW7du5YEHHuCrr74iJSWFSy65pNHnAWJiYoLTUVFR+Hxtaxx//PHHWb58OW+99Rbjx49n1apVXHrppUyaNIlFixYxbdo05s+fz+TJk9u0faXU4SeySgpiH5vuLCUlJSQlJZGcnMzu3btZsmRJu2z3xBNP5MUXXwTgueeeC17k
s7OzOfHEE7nrrrvo2bMneXl5ZGdnc9RRR3HjjTdy7rnnsnbt2nbJg1Lq8NAtSwpN6exHssaPH8/IkSM55phjOOKIIzjppJPaZbsPP/wwV111Fffeey99+/YN9mS6+eab2b59O8YYzjzzTEaPHs3dd9/NggULiI6OZsCAAdx+++3tkgel1OFBumJDZEZGhmn4kp1NmzYxYsSIZtfbtLuExFg3g3oduuqj7iqc31Mp1X2IyEpjTEbLSzYvoqqPOrukoJRS3V1EBQWNCkopdXAiKyjQuQ3NSinV3UVUUJDO7n6klFLdXEQFBQCjUUEppdosooKCNikopdTBiaig0JFRYerUqQc8jHb//fdzzTXXNLteYmIiALt27WLmzJmNLjNlyhQadsFtLl0ppTpKZAUFoLjSS3kHvFNh1qxZLFy4sF7awoULmTVrVljrDxgwgJdeeqnd86WUUu0pooJCbUFhW35Zu2975syZLFq0KPhSnZycHHbt2sUpp5xCWVkZp512GuPHj2fMmDG8/vrrB6yfk5PD6NGjAaisrOSiiy5ixIgRXHDBBVRWVra4/wULFjBmzBhGjx7N73//ewD8fj9XXHEFo0ePZsyYMfzjH/8A4MEHH2TkyJGMHTs2OJCeUkqFo3sOc7F4Dny37oDkgV4fzmChENvKQ+s3BqbPbXJ2r169mDhxIosXL2bGjBksXLiQn/zkJ4gIHo+HV199leTkZAoKCjjxxBM5//zzm3wX8iOPPEJ8fDybNm1i7dq1jB8/vtms7dq1i9///vesXLmSnj17cuaZZ/Laa68xaNAg8vLyWL9+PUBw+Oy5c+eyfft2YmNjg2lKKRWOCCspdGxTc2gVUmjVkTGGP/zhD4wdO5bTTz+dvLw89uzZ0+R2li5dyiWXXALA2LFjGTt2bLP7XbFiBVOmTCEtLQ23283FF1/M0qVLGTZsGNnZ2fzqV7/inXfeITk5ObjNiy++mOeeew63u3vGfaVU5+ieV4wm7ujz9pZRUWPbE8amp7T7bmfMmMHNN9/M119/TUVFBRMmTADg+eefJz8/n5UrVxIdHc2QIUMaHTK7vfXs2ZM1a9awZMkSHn30UV588UXmz5/PokWLWLp0KW+++SZ/+ctfWLdunQYHpVRYIqqk0NESExOZOnUqV111Vb0G5uLiYvr06UN0dDQfffQRO3bsaHY7kydP5oUXXgBg/fr1LQ5vPXHiRD755BMKCgrw+/0sWLCAU089lYKCAgKBAD/60Y+4++67+frrrwkEAuzcuZOpU6fy17/+leLiYsrK2r+NRSkVmSLq9vFQjPg6a9YsLrjggno9kS6++GLOO+88xowZQ0ZGBsccc0yz27jmmmu48sorGTFiBCNGjAiWOJrSv39/5s6dy9SpUzHGcM455zBjxgzWrFnDlVdeGXzr2r333ovf7+eSSy6huLgYYww33HADKSntX2pSSkWmFofOFpH5wLnAXmPM6Ebm/xa42PnqBkYAacaYIhHJAUoBP+ALd1jXtg6dvWVPKVVeP9Ax1UeRRIfOViqyHMqhs58CpjU10xjzN2PMOGPMOOBW4BNjTFHIIlOd+QedWaWUUh2rxaBgjFkKFLW0nGMWsOCgcnQQuuD7gpRSqltpt4ZmEYnHliheDkk2wLsislJEZrew/mwRyRSRzPz8/EaXabnNQKNCOLri2/aUUl1De/Y+Og/4vEHV0cnGmPHAdOA6EZnc1MrGmHnGmAxjTEZaWtoB8z0eD4WFhc1e0PRS1zJjDIWFhXg8ns7OilKqC2rP3kcX0aDqyBiT5/zdKyKvAhOBpW3ZeHp6Orm5uTRVigD4rrgKX8CGho0lcTTxQPFhz+PxkJ6e3tnZUEp1Qe0SFESkB3AqcElIWgLgMsaUOtNnAne2dR/R0dEMHTq02WWuuvcDdhfbh8bevP5kxqT3aOvulFLqsNRiUBCRBcAUoLeI5AK3AdEAxphHncUuAN41xpSHrNoXeNUZ/8cNvGCMeaf9st5S
vg/VnpRSKnK0GBSMMS2ODW2MeQrbdTU0LRs4tq0ZU0opdejpMBdKKaWCIiooaI2RUkodnIgKCkoppQ5OxAYFbWhWSqnWi9igoJRSqvU0KCillAqKqKAQ+k7kjn41p1JKRaKICgpKKaUOTkQFhXt/OCY4bXR4PKWUarWICgqTj64bXVVHh1ZKqdaLqKAQSoOCUkq1XsQGhYBGBaWUarWIDQp7SqooqfJ2djaUUqpbiayg4PfRh314qGb2sys5ae6HnZ0jpZTqViIrKOxezVee6zjRtRGA0ipfJ2dIKaW6l8gKCnE9AUihvIUFlVJKNSYyg4KUdXJGlFKqe4qsoOCx72TWoKCUUm3TYlAQkfkisldE1jcxf4qIFIvIaufz55B500Rks4hkicic9sx4o1xRlJh4emj1kVJKtUk4JYWngGktLPOpMWac87kTQESigIeB6cBIYJaIjDyYzIZjv0mgh2hQUEqptmgxKBhjlgJFbdj2RCDLGJNtjKkBFgIz2rCdVtlPIilo9ZFSSrVFe7UpTBKRNSKyWERGOWkDgZ0hy+Q6aR2q2CRom4JSSrVRewSFr4EjjDHHAg8Br7VlIyIyW0QyRSQzPz+/zZkpJlHbFJRSqo0OOigYY0qMMWXO9NtAtIj0BvKAQSGLpjtpTW1nnjEmwxiTkZaW1tRiLdI2BaWUaruDDgoi0k+cV56JyERnm4XACmC4iAwVkRjgIuCNg91fS+raFHRAPKWUai13SwuIyAJgCtBbRHKB24BoAGPMo8BM4BoR8QGVwEXGGAP4ROR6YAkQBcw3xmzokKMIUWwScEuARCopI76jd6eUUhGlxaBgjJnVwvx/Av9sYt7bwNtty1rbzDr1WPjiBVKknDKjQUEppVojsp5oBnwx9qlmbWxWSqnWi7igUBOTAkAP7ZaqlFKtFnFBwRudBEAyFZ2cE6WU6n4iLijUuGIBiKO6k3OilFLdT8QFhV49bJtCnNR0ck6UUqr7ibigMHxgH0BLCkop1RYRFxSItt1QPWhJQSmlWivygkJUNH5cxImWFJRSqrUiLyiIUC0e4rSkoJRSrRZ5QQGokVhtU1BKqTaIyKBQRSwe7X2klFKtFpFBoVpitKSglFJtEJFBoYpYbVNQSqk2iMigUCMe7X2klFJtEJFBweuK1ecUlFKqDSIyKNS4PNqmoJRSbRCRQSE+IVHbFJRSqg0iMiiMOqIfCVE19E2O7eysKKVUt9JiUBCR+SKyV0TWNzH/YhFZKyLrROQLETk2ZF6Ok75aRDLbM+PNiYpJIEFqCJhDtUellIoM4ZQUngKmNTN/O3CqMWYMcBcwr8H8qcaYccaYjLZlsQ2i44gOVGMCgUO2S6WUigTulhYwxiwVkSHNzP8i5OsyIP3gs3WQouNwEcAV8HZ2TpRSqltp7zaFnwGLQ74b4F0RWSkis9t5X01zhs+ONlWHbJdKKRUJWiwphEtEpmKDwskhyScbY/JEpA/wnoh8Y4xZ2sT6s4HZAIMHDz64zETHARBrtAeSUkq1RruUFERkLPAEMMMYU1ibbozJc/7uBV4FJja1DWPMPGNMhjEmIy0t7eAyVFtSCFQe3HaUUuowc9BBQUQGA68AlxpjtoSkJ4hIUu00cCbQaA+mdldbUghoSUEppVqjxeojEVkATAF6i0gucBsQDWCMeRT4M5AK/EtEAHxOT6O+wKtOmht4wRjzTgccw4GCbQpaUlBKqdYIp/fRrBbmXw1c3Uh6NnDsgWscArUlBWowxuAEJqWUUi2IyCeaa4NCHNX49Qk2pZQKW4QGBVt9FEcNfqNBQSmlwhWZQcFtxzyKxaslBaWUaoUIDQoeAGJFg4JSSrVGhAaF2pJCDTr8kVJKhS9Cg4JTUsCLT6OCUkqFLUKDQl2bQo1fg4JSSoUrMoOCKwq/uIkVL1c+uaKzc6OUUt1GZAYFwOeKJRYv33xX2tlZUUqpbiOCg0IMsfqeZqWUapWIDQpeiSEWfcmOUkq1RuQGBWKI
FQ0KSinVGhEbFCTag0erj5RSqlUiNiik9kgmzuXje32TOjsrSinVbURsUJBoD/3iwavPKSilVNgiNijgjiXGePHqE81KKRW2CA4KHqKpwevTAfGUUipcERwUYok2NVp9pJRSrRDBQcGjQUEppVoprKAgIvNFZK+IrG9ivojIgyKSJSJrRWR8yLzLRWSr87m8vTLeIrcHd6AGr1+rj5RSKlzhlhSeAqY1M386MNz5zAYeARCRXsBtwAnAROA2EenZ1sy2ituD29To0NlKKdUKYQUFY8xSoKiZRWYAzxhrGZAiIv2Bs4D3jDFFxph9wHs0H1zajzsWd6Aar99g9D3NSikVlvZqUxgI7Az5nuukNZV+ABGZLSKZIpKZn59/8Dlye3AbL0JAq5CUUipMXaah2RgzzxiTYYzJSEtLO/gNOi/aicGnVUhKKRWm9goKecCgkO/pTlpT6R0v+EpOfVZBKaXC1V5B4Q3gMqcX0olAsTFmN7AEOFNEejoNzGc6aR1PX8mplFKt5g5nIRFZAEwBeotILrZHUTSAMeZR4G3gbCALqACudOYVichdQO07Me80xjTXYN1+aksK4tXqI6WUClNYQcEYM6uF+Qa4rol584H5rc/aQQopKWj1kVJKhafLNDS3O6ek4NHqI6WUClvkBoXouoZmrT5SSqnwRG5QCGlT0OojpZQKT+QHBWqo9vk7OTNKKdU9RHBQqGtoLq70dnJmlFKqe4jcoBAdD0AiVeyv0KCglFLhiNygkNQfgL5SxH4tKSilVFgiNyjExGPietLfVcT+iprOzo1SSnULkRsUAElOZ3DUPq0+UkqpMEV0UCB5AANcWn2klFLhiuyg0GMgfUyhVh8ppVSYIjsoJA+ghykBX2Vn50QppbqFCA8K6QCkeAs7OSNKKdU9RHhQGABAL//eTs6IUkp1DxEeFOzroFN8BZ2cEaWU6h4iOyjE9QQgPlDayRlRSqnuIbKDgqcHAAmBsk7OiFJKdQ+RHRSi3FRJHPEaFJRSKixhBQURmSYim0UkS0TmNDL/HyKy2vlsEZH9IfP8IfPeaM/Mh6MiKklLCkopFaYW39EsIlHAw8AZQC6wQkTeMMZsrF3GGHNzyPK/Ao4L2USlMWZc+2W5dSpdibiqSjpr90op1a2EU1KYCGQZY7KNMTXAQmBGM8vPAha0R+baw+7qWJIoY8icRZRW6XAXSinVnHCCwkBgZ8j3XCftACJyBDAU+DAk2SMimSKyTET+p805baN9gXiSqQDg7XW7D/XulVKqW2mx+qiVLgJeMsaEvv/yCGNMnogMAz4UkXXGmG0NVxSR2cBsgMGDB7dbhopJYITsACDZE91u21VKqUgUTkkhDxgU8j3dSWvMRTSoOjLG5Dl/s4GPqd/eELrcPGNMhjEmIy0tLYxshafExJNMOQC+gGm37SqlVCQKJyisAIaLyFARicFe+A/oRSQixwA9gS9D0nqKSKwz3Rs4CdjYcN2OVGwSSJZKXAS47Y0N3P7GBqq8/pZXVEqpw1CLQcEY4wOuB5YAm4AXjTEbROROETk/ZNGLgIXGmNDb8RFApoisAT4C5ob2WjoUSrDvak6igqLyGp76Ioenv8g5lFlQSqluI6w2BWPM28DbDdL+3OD77Y2s9wUw5iDyd9CKTQIAPaScYpMIQKWWFJRSqlGR/UQzUIINCrXtCgA94rTBWSmlGhPxQSG0pFDLaHuzUko1KuKDQpmrtqRQEUx7btkOSqq8+LU3klJK1RPxQaEgkARAWt1wTGQXlDP29neZu3hTZ2VLKaW6pIgPCvmmB2XGw1D57oB5b6zZ1Qk5Ukqprivig8Kx6SlsN/0aDQqCdEKOlFKq64r4oPDMVSeQftRYhsqB4x7ll1V3Qo6UUqrrivig0CM+mp6DRpAu+cRQf5RUbWhWSqn6Ij4oAJB6FFFi+MWYAw93X3lNJ2RIKaW6psMkKBwJwG+88zhS6o/ld9xd77F0S35n5EoppbqcwyMo9BkJ
g06Ab7/kkqj3D5i9fHthJ2RKKaW6nsMjKETHwc/ehaNO5/SoVUD9tgS36/D4GZRSqiWH19Xw6GkMkr0Mb1CFFB2lXVOVUgoOt6Bw1GkAnOiqP3q3O+rw+hmUUqoph9fVMHkgBqG3FNdLdru0pKCUUnC4BQVXFAFPT1IpqZe8aud+Zjz8OdU+fc+CUurwFtZLdiKJKzGNXuWl9dIWrbVPO39bWMHwvkmdkS2llOoSDq+SAiDxqXy/fxPzRMgpKGfJhgPHSVJKqcPBYRcUSEglxZQ0OsslMPXvH/OLZ1ce4kwppVTXEFZQEJFpIrJZRLJEZE4j868QkXwRWe18rg6Zd7mIbHU+l7dn5tskvjeUFzQ6yx8w+lY2pdRhrcU2BRGJAh4GzgBygRUi8oYxZmODRf9jjLm+wbq9gNuADOwTYyuddfe1S+7bIj4VKosQApgGMbHGH+ikTCmlVNcQTklhIpBljMk2xtQAC4EZYW7/LOA9Y0yREwjeA6a1LavtJKE3mAD3Tht0wCyfv66YYIzh86wCzvrHUqq82itJKXV4CCcoDAR2hnzPddIa+pGIrBWRl0Sk9oob7rqIyGwRyRSRzPz8DhygLr43AD8ZGUdMg4fWvCElhYCBexdvYvOeUjbsqv9cg1JKRar2amh+ExhijBmLLQ083doNGGPmGWMyjDEZaWlp7ZStRsT3AsBVWcTKP51eb9afXt8QnPYFAqSnxAOQt7+q4/KjlFJdSDhBIQ8IrWtJd9KCjDGFxpja15g9AUwId91DLsGWFKgoIMZd//A37a7rleQPGKKd+T6nBBEIGC6b/xWfbtWhtpVSkSmcoLACGC4iQ0UkBrgIeCN0AREJ7fl/PrDJmV4CnCkiPUWkJ3Cmk9Z5nOojyguIkqaHt/AFDCWV9k1tVV4bFMpqfCzdks+1z33d4dlUSqnO0GLvI2OMT0Sux17Mo4D5xpgNInInkGmMeQO4QUTOB3xAEXCFs26RiNyFDSwAdxpjijrgOMKXkAYSBcW5RDUz5tHY298NTtc2NNd2V9Veq0qpSBXWMBfGmLeBtxuk/Tlk+lbg1ibWnQ/MP4g8ti93DPQ8Agq3IiKc41pGZuBo9tCryVWqnDGRat/pbPRhBqVUhDr8nmgGSB0OhdugooiHYx5kbvoXzS5eW31U27agIUEpFakOz6DQ2wkK360FYErvxoe9qJVbVEFZtQ9fsKRQf/4767+jrNrXIVlVSqlD6fAMCqlHga8Stth2AyncFpz10xMGH7D4K6vyOPuBT4MX/sqQh9nW5Rbzy+dWMvq2JVTUaGBQSnVvh29QANj4mv1blE2vOPtTpPeMa3SVb4sqOPMfS4Pfl2cXAnDePz8Lpi1ep6OrKqW6t8MzKPQ+2v4tcR6Z8Ffz8ezhfD7nB5SHWQ300eYDn1Vwt+Jdz9pYrZTqig7PoJDUF46dZaedUkNy+Q4GpsRRXh3eOEePfrLtgLQbF65m5iPNN1oDFJXXMPTWt3l22Y7w86yUUofA4RkUAM5/CE6dA+f+w34vzALgqpOGMqhX41VI4cjcsY/l2YV8V9z00Bj5pfbh76c+397m/SilVEc4fINCVDRMvRWGnGKH0961CoDB6x7k0wvj+PiWKQD88ZwRTW6iqSqgC+ct4+wHP62XtiKniPV5dmC92mfmSqrqqqoe/WQbq749cETx3cWVLPzq27APSymlDsbhGxRqicCQk2H7Uqgogo/vhcx/M6R3Ajlzz+HqU4Y1uer0Bz5tcl5ReQ0AefsrufftTfz40S8596HPKKv2Ue2zzzvkl1aTtbcMYwxzF3/DBf86sOrp3Ac/Y84r69qly+ve0io+2rw3+L3GF+DjkO8Hq8rrD77vWinVPWlQABg62TY6r3vJfv9uXb3ZQ2Q3PSnhvGMH1Ev/5rvSZje7PLuQk+Z+yGNLs4Npo29bwrb8suD32c9mBoNEYwqd4DL6tiUHNIJv2FXMra+sIxAIr9H6p48v58onV/DhN3vwBwx/
f28zVzy5gsyc5kce2V1cGdb273tnM9e98DXLnJ5ZSqnuR4MCwJDJ9u+nf7d/C7ZATUVw9kLPXP4U/RzJnrBGBQn6/ctrG02/e9Gm4HRVjb9eKeCEe95vcns799Xl6cUVOznnwc9Y8NW3vL9pD4vX1b9DL67w8mLmTobMWcT+iho+2ryXrL02GF31VCYPfLCVxz6xwaq2jaMxH2zaw6R7P+Sjb+pKFHtKqvjNi2uorKnfKL+n1Laj7HW2V+X18/uX1pK1t4yxty/hy20HHyyWZxeGHQRba9f+SgrLmv4tlDocaFAA+4Rz2ggoc54zMAHY61y4ywvpZ/I5PzWXX59xdL3SwtwfjglONyxFAOQUVhyQBvUvwr6A4YuQi+Wekmp2FlVw08JVPNagh9PN/1kTnL7rrbq3oc5+diXXPF83cmtOQTnH3/M+v3vJBqXpD3zKlU+uINRrq+pGMPc2c5Fdm2vbQb4Oae849W8f8fLXubwdEogWfvUtm52SU+0Agu9v2sN/Mndy9gOfUlLl49+fZdOcjbtKGu3VVeujzXu5cN4ynvoip9nttNX3537Iifd+0CHbVqq7aN2tb6QSgbP+As/9EBL72eDwwk9g1kLw2wu4e38OqVEVPDTrOAakeDj+iF6cPrIvc16xVU0PzTqOm04fzml//6RVu95bWs0NC1bVS7vzrY28t3EPr63eVS990+4S/AHD3tIq4mOjKG1QnTRkzqJG97G7kZ5QgZBG8vJqH/OWbiNg4JenHsmLmTsZ3ieR9btKgsGg9v3V2wvKg2NBVdT42FlUwSn3fVRv27VBISHWXW/d1ITYessZY3hu+bf8z7gBLMsu4ufPZALws5OHEh3l4vXVeazLLeaP547E6w/wifNsSE5heb3t1PgCFJXX0K+Hp9HjD8dnWwsA8PqbL4XsK6/hhoWr+HRrAbMmDubekBuDZ7/M4U+vb2DL3dMPeFeHUt2FBoVaR50G594PgyfBs/8Dpbvh5Z/BxJ/XLbNrFRz5A26d3niPpCPTEsmZe06TF+dwvbdxT5PzjvzD203Oa43cfXXtBPOWZrO9wF5o5y7+pt5y86L/TnrUOB77BGaOT+ehD7OC85ZuLaj3trpaf359Az/JGES1t35byX8yd/JNUIfnAAAZgUlEQVTXmWOD37/YVsifXlvP+xv38MmWuocBcwrKGd43iRsXrgYgLiaKGn8gWEJwu1zsr6jhpLkf0iMumhOGpfLqqjy2/mU60VEHXoz3lFRxwj0f8N9fTuL4IfVHwy0sq2bC3U1X2TX05tpdfOoEkAVffVsvKNT+dqVVXlITYxtdv62+LawgyeOmZ0JMu25XqYb0diZUxpXQ5xi45gu48HnYvwPe/SNEJ9j5u1YdsMo5Y/rzp3NHdmi2erfzBaah2oDQkBsfp7u+ZrrrKwDO+MdS3lhTV3ppLniNveNd/vjaugPSa9sDNuwqZl+FbUQPDQi1+wltr3jow6xg+wfA8u2FZNz9PuU1fnYVV/GqUxVWXOllyJxFDJmziCqvP9hluLbh+6nPcwBbQnni02z++Nq6FgNCZY2f8x76jMycIgIBc0AjujGGPSW2JFZbC1daVb8EFwgYKmv8FJRVc/QfF7OiiYb90ipvcFsA1T4/OwrL+a64isl/+4gz/tFyKbS4wstX2zv3lSWqe5OuONxCRkaGyczM7OxswH+vgA2vQlxPW60U8MEvlkKMfXcznz9gn3E47pJ6q133zDLe37ibamLwUE0MXh64YioAVz29gvsvHMcZI/vy2qpd3PXWxuAAe3+5YDRb95TVqzN/4rIMTh/Zlwl3vRfsidQaibHuNndnTZd8Pou9kQKTTEb1I0D4w3iE6kchiVJJlknn2EEprNm5v03bacn/jBtQr8pt4pBe/ObMo/l8WyEPfrAVgOF9EtlRWBGs0mpMztxzAMjMKaKovIbZz64E4Jh+SQf0OJs0LJUvswsZ2T+ZjSGvc/3TuSPZWVTBU1/kcPqIvry/
aQ///OlxXP/CKk47pg/zLssgYEy9ks3MR74gc8c+tt1zNlEu4XcvreHFzNx6+/vvLycxsn8yCbFujDEEDPVeFnXJE8v5LKuA9XecRWKsmyqvH5dIi9VZ2fllVHkDeKJdDEtLbHZZ1TWJyEpjTMZBb0eDQjOqiuHhE2HC5TDoBFutNPEXcPZ9sGWJbXcAuN02xuL32faJ167F7N+BXPUOvH495GbC7I/BWwHxDV7ms/cbtq1fzofuU/j5ZPtMRN7+Sr4trKBPcixHNvgPet3zX5O7r4LJR6fx88nDgm+I600xBfSot+yI/snccf4oPtuaz4Mh1T4Ay/9wGvs++zfVy57gPt+FBIZM4csGd8ETZRMvxt4FwAlV/6z3IqLjBqew6tv6F/d47F1uBfXr9l+N+TOpFDO55oEmfuiu5Zh+SUw9pg+PfNx0o3d7ue9HY/ldg15qx/RLondiLJ9lFTS53q3Tj6FXQgy/fWktvznjaEYNTGb0wB5M/IttKF90w8kkxLiZ8v8+ZtKwVBbMPpEqrx9PdBRPfb6d7QXl3DFjdHB7oVWey249Lez2mU27S3h22Q7umjH6gDcZ+gOGvH2VDE6ND2tbHaXa58fnN8E2rsoaP3ExUWGt6w8Y1uUVM25QSkdmsV1oUDhUAgFwOXdZi+fA8kdg8u9g+aNQ7dwZ/jYbPD3gmfPBHQu710JFAdySBf8+A/Zth++dbZ9/uGE1FO+073NIGQTzp0FlEdywCir3w9dPw4nXQdrRdfsv3QU90hvN3udZBRxZmkm/1y8k56ynmPfdUYwe0INRA5I5tsE/5FtfWUe/ZA83nj4cCrLgnxMA+LTH+Rx37ZMkxrrJzCniiU+3M31MPyYUv0v6RzcBsOTYB1nlOYFXV+Vy5sh+zJl+DKNuq3vd9uzJw5ix7npKKyp5Z8Lj/Gh8On9+Yz1nJn/LNVm/BGD9T7/mD+/u5rviKvokx1Ja5WOH00PrngvGMHNCOkf/cTFJHjdRLuGotETGpPfgSafa592bJwdHqk2IieLxyzP4eHM+85Y236upNdwuCb43IxL1SYoNdhk+6ahUjIHoKFe9Krxj+iUxqFc8bpdQXuOnvNpHr4QYYt0uKmr8xEXbC2rAGBavtz32BvTwcMKw1GCVnYjw4Td7Ka70MvnoNJI9boorvSTEuIOllto3GbpcQnSUgLEvsKotARWV1+CJjiIlPhpjbIeF0iovcdFRxLpdiPOO9WAoColJ4nwRgTfX7KLaF+DHE9LJ21/JF9sKmXx0Gh63i5zCcvome+iX7ME4x4SB8hofBWU17Cgsp6CshlOPTqN3YiwGQ40vQKy7LqiI1O16y94yUhNiSImPRpB686D+C7pCL73GmZPsieb280eFf0JDHNKgICLTgAew72h+whgzt8H8XwNXY9/RnA9cZYzZ4czzA7WVy98aY85vaX9dKiiE8lbCv8+0L+dJPcqOnfTK1baXUt7XsPS++suf83dY9Jv6adPvg6V/g/J88KSAyw0VhTBlji197Poa3HFwzeeQeiR89Tgs/h3MeBiKsuGU30Bxnp0HULnP7mPDK9BvDMxeWhfEalUUHVhC+fJfsORWSB4IPQbBz5ZwgKV/gw/vttPHXQJn/x2i6+4gvf4A5dU+UuJj7HMdfz0C/DVw4xroOcT+q3/+x5D1nl3h4pdh+On1dlFS5SXW7Qr+JzPGBP+z1/L5A1T5AiTGusnaW0q/HnEkOnd9Xn+A9XnFHDe4JyVVXqq8fqJdLvaWVlPptRewrXtL8bijGNY7niH7l+MaejL+gB/fK9cSdcpNRA0cx5Y9ZaTER9M32cPOogqeW76DUQN6kORx8+jH2xjcK57rph7FK6vy+MExfSit8uJ2uUiMdfPNdyXEuF1s2l3KiP5JxES5+CyrgNx9lZw7tj+x0VG8sXoXvqIdFBUVUJgwnJIqL6VVPi44biCpCTF4/QEWrfuOo/oksLekmmumHMmOwgrW7yqmssbP8u1FDO4V
T6XXT35pNSnx0QxJTcBg79ZTE2IYmBJH5g7bW2zcoBRWO9V0Sc5vVVrtY1jvBLKdNqSx6T2IiXLh9QdY43Q9BhiSGo8nOgp/wFBa5aPa56dHXDQuEQz2IugSwSVQVO6loKyaUQOSKa704hJ7IQwYQ1FZDeU1foalJWAMJHnc7NpfiSc6iugoF1EuwesPUFHjJybKZS+gzrYFOwyM2yW4XYI42032ROP1B4IPfNZeTOtdYBtc0vL2204V/Xt48AUM+aXVDEmNJ8btYsueMqJcQp+kWFxSF0iio1zsLq4k2RPN3tJqBqbUjYfmctl7tbr9OXnA9vTrGR9NfIw7OM9QPzA0/Pddlw69EmJ44/qTG53fkkMWFEQkCtgCnAHkAiuAWcaYjSHLTAWWG2MqROQaYIox5kJnXpkxplWVlF02KICtIireCckD7PMM96bbHks7PoeBEyDXeR5AouwyxTsP3EZCmr2A71oFl78F78yxQcVbbgPNskegzwg473549Rewu+75BFKPsoP3DTnFBpO9GwGBXkNt0Pjx0zByhg0uRdvtMOGPT4WZ82263wdfPAif32/bSYZOhjUL4aa1tu1EBAJ+KM6FT+6DLe/A4BPhm7cg42e2i26fkTDpOrvcO3MgsS+kH29LSgCn/dkGr9UL4LVfwpRb7fAhP/gjTP6tXaa8EPbnQFQs5H8DY2ba9L2bYOPrMOl6iG1l3XZxHvQYCPtybFBqaPtSePo8G5hryuGDO+C4S2HGPxvfnjHw/EwYdcEB7UZt8sJF9nzd1PhDjUodjPYKCuF0SZ0IZBljsp0dLwRmAMGgYIwJ7ai+DGiH/0FdVJTbXoBrpR8POZ9CryPhkpfhwePAVwPDz7B372AvzAVbYcyPYdsHcNnr9sLqrbSN1pN/Cx/cCXEp9mLaayi8dg38axJg4JhzbXCJ723XHzoZ9n8LSf3tRW3j6zYYvPwze2e/7r/2Ig51DeTrXrJB4cuH7MUQ7LqpR0FNKdw3FE64BkacC6/MrnvXRP9xcOFz8NJVsPY/UFNmG9d7pMOq52CrbdMg5Qj7t98Ym37STbaKrd9YW9229kXYtbrud3t1NmS9D24P+KpsyanPCPjoHti5HDa8Zttulj9mj/f4q22A+mYRnHWP/a1CrX/Z5rE2AF3+pl0PbPBa+x/7uwBsejP4Kla2vAPlBbb6Lyq6/jaLc20efdUHHxSMgdyvbCCvKrb7U6oLCicoDARCb3dzgROaWf5nwOKQ7x4RycRWLc01xrzW6lx2ZRc+Z+/c+46C2CR7h1u1H75/gw0KSf3hlFtgzzo4/Q7we+uqYGp7MQ07FYaFPEl77EVw1Onw/m2w9T37/ERiGlSV2IvYqB/a4FTr/IfsHf7UP9geU8W59m495zPY9qFdZtMb8H+joCQXjp5uBwEc82MbXGp99Zj99BpmA9E3b9kgIGIDSm2QqyiEFy+zweGU38DmxfYOeOAEe/wvXQlfP2NLOFP/15a3B02ETW/ZvJXusRfbfmMhJsFeJN+ZU5ePMT+2F/BnfwjisvlY9i9bAgAbQKbNtSWZI39gL/Cb3rDzPnZqNrcssUE07RjInA9v32LTXdE2iINtu1n2MPztSEhOh/MesNVbte1ItaW+3BXgrbL76TXMBsL8zfY3GnQ8lOyywa1hFV2o/Tvs7wa2NLT/W1s6nO7k1xj7OyvVycKpPpoJTDPGXO18vxQ4wRhzfSPLXgJcD5xqjKl20gYaY/JEZBjwIXCaMeaAbh0iMhuYDTB48OAJO3ZEwAtoKvfbHkfJBw6BEbbWXix2rbbDdsQkQO5K29A9/lJY+ZSdf/od9nmM2jvVmnK4ZwAMmwL7dkB6hn3HhN9rSw+1va2qS+G+YbaEU1EI0fHw8w9tFU9NOeQ4bSApR9gG7NoL+M8/tMGiaDs8cpJtQ6kusfu/aa39u28H7PjCloZyM22g/e/lsPltuGgBFG2zJYaJP7frL/kD9BgMxbUBTQDjlMi22CR3
nH0P95gfQ9YH9jc58jRbRbb4t7Y32U9fhMcm24C1Z6Otxho5w5bGjp1lA9Kyf9X/ffuPsx0GTMCel9EzbdDqdaTtYRblhuXz7Hk/6ca6c1dbkgE45/9sKapgC9yy1Zas1v0XLnkFkvvX7SvnM8j+xJZ4hp4S/r8BdVg6lG0Kk4DbjTFnOd9vBTDG3NtgudOBh7ABodHxmEXkKeAtY8xLze2zS7cpdDcVRfbC+9U8+N70xuvaS3bZdgFx1Q9AFUUQkwhu5ynatS/ai37PI2x6U3X+2z6EZy+w038uApfTU2PTm7aqK3W4bUOobSxvTHmhbacZcV79PAUCMO9Ue9d+/NW2V9egiTav+3LgmRm2lJX1Xl1gSOwLl74GfUfa5R6fahvNQxu9q0tt77LNi2yVW74z9lXqcCi0zzgw4nxbIpEo2xFg2SO2Wi6xj+1hNul6+/vWlkpOvBam3WuD5uvX2yDniob+x8IO593eP/iTbdD3VdmANXgSXPAoLHsUPr6nLn9jL7JtMlnv20C3Z4MNtnmZMOA42+tNHdYOZVBwYxuaTwPysA3NPzXGbAhZ5jjgJWyJYmtIek+gwhhTLSK9gS+BGaGN1I3RoBAB1r9s21bGzWr/bX+73PacumiBfbVqqLK9tkfW4z+wJY6AzzaUxyaFv31j4MO77Ki5p/7edgzofbRtP3pogr1z/+FjdtmA3wbTFy+1QQ/sRXrgBFjxhG1PWfdfGzSOu8S2Le1cDogt9QS89u+Mh+0Ff8OrNiiV5MLYC2012bJHbM+2mERbnZfQB8r3whEn2cAZk+iUDhNtVWThNluVmDK4XX5u1T0c6i6pZwP3Y7ukzjfG/EVE7gQyjTFviMj7wBigdtjMb40x54vI94HHgAB2SI37jTH/bml/GhTUQWuPOvryAohNrispga0SjI5r/M68YKutGus7GhB4+lwbAJIGwAWPwNBTbTXYisdtqc1bZad/9O+63leb3rLBpM8IW9VXu++3boaVT9uL/raPIKmfHZ/r6Gm2S3FRtq1+K9hi950yCKb+0VZzJaTVPSuTPEDbLiKUPrymVFfn99k7/sR+9Z7vCPJV27aHuJ4tb8sY28CekAZle+zns/vtszC1DdzG2Abs8gLbJtOwO7S47OfoabYa0Fdl24eS+tltD55kq7+WP2bzdORUW9Iqyra91PI3284FQ0+BvmNswCrcZvfb+6iD/rnUwdGgoJRqmt8H362BmCRboojvZaumqkpsLzJfjb2oe1Jsm5In2QYasIHHGPtUPthAYgJOdZczjlaPQbZabv3Ldl5cLxswEvvaxvKi7Xa7PY+wjfomYHtbBXy2TcUVBUhdqSU2yQae3sNtTzB/DexcZgNRv7G2u3BVse3ZZwykfc9uy1tpO0V4km0gCwTsMzpJ/ewzMIMm2m1XFNreaCJ2295Km+6qezKZgN/Oc3vscsbY71ExzZeuAn6bB5fbbk/E/r4idd2cjbE3AS53/Z6D7UiDglKq/RgDe9ZD6Xf2QhqTZHtjRcfZALB7jX0w0FdtG7eXz7MN+8NPt6WGgi02iJTttQ8RJvS2PdW+W2u74IrLjkCM2JKHMdhxLZy/3krqDwLRjmqDWlQsGH9dYENsJwwTsPsPeG1yXC97Qa8qsWkxiba3Xf0fDKrL7EXeW2G323B/YEdY9vSwHRlqSu28hD52+YpC22mhNpgkD4DrV9BWh/LhNaVUpBOxDx72q3s/BH1DhoRPn1A33fMIGP2j8LdtnIt/w+FXQnkrbbAo2GoDU1S0bdzvN8Y+B2QC9uLq6WEv6oVZ9o4+Os5eVKuKbQcDfw0MGG+na8psm05NhQ1Sxbn2rj8mwa5bVWzHHXO562+rONdevD3JdtnyQlvV1lBMgi0lxCTYZ44CAZu3gM9uD2zJpmq/7QmX3N+2I5XucoJDmj2ugN9+WtMZogNpUFBKdSyRlhu3o+PsA6B9GxkMLr2Rm98mBogM6uk8YT+kbeMIHc70JTtKKaWC
NCgopZQK0qCglFIqSIOCUkqpIA0KSimlgjQoKKWUCtKgoJRSKkiDglJKqaAuOcyFiOQDbX3LTm+goB2z05ki6VhAj6er0+Pp2lo6niOMMWkHu5MuGRQOhohktsf4H11BJB0L6PF0dXo8XduhOh6tPlJKKRWkQUEppVRQJAaFeZ2dgXYUSccCejxdnR5P13ZIjifi2hSUUkq1XSSWFJRSSrVRxAQFEZkmIptFJEtE5nR2fkKJyCAR+UhENorIBhG50UnvJSLvichW529PJ11E5EHnWNaKyPiQbV3uLL9VRC4PSZ8gIuucdR4U6di3s4tIlIisEpG3nO9DRWS5s///iEiMkx7rfM9y5g8J2catTvpmETkrJP2QnksRSRGRl0TkGxHZJCKTuvm5udn5d7ZeRBaIiKe7nR8RmS8ie0VkfUhah5+TpvbRAcfyN+ff21oReVVEUkLmtep3b8u5bZYxptt/gChgGzAMiAHWACM7O18h+esPjHemk4AtwEjgPmCOkz4H+KszfTawGBDgRGC5k94LyHb+9nSmezrzvnKWFWfd6R18TL8GXgDecr6/CFzkTD8KXONMXws86kxfBPzHmR7pnKdYYKhz/qI641wCTwNXO9MxQEp3PTfAQGA7EBdyXq7obucHmAyMB9aHpHX4OWlqHx1wLGcCbmf6ryHH0urfvbXntsX8duR/tkP1ASYBS0K+3wrc2tn5aia/rwNnAJuB/k5af2CzM/0YMCtk+c3O/FnAYyHpjzlp/YFvQtLrLdcB+U8HPgB+ALzl/McqCPlHHjwfwBJgkjPtdpaThueodrlDfS6BHtiLqDRI767nZiCwE3shdDvn56zueH6AIdS/kHb4OWlqH+19LA3mXQA839jv2dLv3pb/ey3lNVKqj2r/I9TKddK6HKcIdxywHOhrjNntzPoO6OtMN3U8zaXnNpLeUe4Hfgc4bycnFdhvjKl9I3ro/oN5duYXO8u39hg7ylAgH3hSbHXYEyKSQDc9N8aYPOD/Ad8Cu7G/90q67/kJdSjOSVP76EhXYUsr0Ppjacv/vWZFSlDoFkQkEXgZuMkYUxI6z9hw3uW7gonIucBeY8zKzs5LO3Fji/aPGGOOA8qx1QZB3eXcADh14DOwwW4AkABM69RMdYBDcU4OxT5E5H8BH/B8R+6nNSIlKOQBg0K+pztpXYaIRGMDwvPGmFec5D0i0t+Z3x/Y66Q3dTzNpac3kt4RTgLOF5EcYCG2CukBIEVE3I3sP5hnZ34PoJDWH2NHyQVyjTHLne8vYYNEdzw3AKcD240x+cYYL/AK9px11/MT6lCck6b20e5E5ArgXOBiJwDRQp4bSy+k9ee2eR1RF3ioP9i7vWzs3VFtI8yozs5XSP4EeAa4v0H636jfqHWfM30O9RvOvnLSe2Hrv3s6n+1AL2dew4azsw/BcU2hrqH5v9Rv7LrWmb6O+o1dLzrTo6jfoJaNbUw75OcS+BT4njN9u3NeuuW5AU4ANgDxzv6eBn7VHc8PB7YpdPg5aWofHXAs04CNQFqD5Vr9u7f23LaY1478z3YoP9geCFuwLfT/29n5aZC3k7HF0LXAaudzNrZ+7wNgK/B+yD9YAR52jmUdkBGyrauALOdzZUh6BrDeWeefhNGg1A7HNYW6oDDM+Y+W5fwjjXXSPc73LGf+sJD1/9fJ72ZCeuQc6nMJjAMynfPzmnMB6bbnBrgD+MbZ57POBaZbnR9gAbZNxIstzf3sUJyTpvbRAceSha3vr70ePNrW370t57a5jz7RrJRSKihS2hSUUkq1Aw0KSimlgjQoKKWUCtKgoJRSKkiDglJKqSANCkoppYI0KCillArSoKCUUiro/wNFyQzlCzFVqQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "complete - train time: 22814s, best epoch: 294, best loss: 0.184430, best accuracy: 95.96%\r"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 432x288 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import paddle\n",
    "import paddle.fluid as fluid\n",
    "from paddle.utils.plot import Ploter\n",
    "import numpy as np\n",
    "import time\n",
    "import math\n",
    "import os\n",
    "\n",
    "# Training configuration\n",
    "epoch_num   = 300   # number of training epochs, typically in [1, 300]\n",
    "train_batch = 128   # training batch size, typically in [1, 256]\n",
    "valid_batch = 128   # validation batch size, typically in [1, 256]\n",
    "displays    = 100   # report (plot, print, log) every `displays` iterations\n",
    "train_size  = 50000 # training-set size: shuffle buffer and per-epoch iteration count\n",
    "\n",
    "start_lr = 0.00001                              # warmup start learning rate, typically in [1e-8, 5e-1]\n",
    "based_lr = 0.1                                  # base (peak) learning rate, typically in [1e-8, 5e-1]\n",
    "epoch_iters = math.ceil(train_size/train_batch) # iterations per epoch, rounded up for a final partial batch\n",
    "warmup_iter = 10 * epoch_iters                  # warmup length of 10 epochs (multiplier typically in [1, 10])\n",
    "\n",
    "momentum = 0.9     # optimizer momentum\n",
    "l2_decay = 0.00005 # L2 regularization coefficient, typically in [1e-5, 5e-4]\n",
    "epsilon = 0.05     # label smoothing rate, typically in [1e-2, 1e-1]\n",
    "\n",
    "checkpoint = False                   # True: resume from the checkpoint stored at model_path\n",
    "model_path = './work/out/ssrnet'     # checkpoint path prefix\n",
    "result_txt = './work/out/result.txt' # training log file\n",
    "class_num  = 10                      # number of classes\n",
    "\n",
    "# Create the output directories up front so logging cannot fail on a fresh workspace.\n",
    "for out_dir in (os.path.dirname(result_txt), os.path.dirname(model_path)):\n",
    "    if out_dir:\n",
    "        os.makedirs(out_dir, exist_ok=True)\n",
    "\n",
    "with fluid.dygraph.guard():\n",
    "    # Data readers: shuffled training set, sequential test set used for validation\n",
    "    train_reader = paddle.batch(\n",
    "        reader=paddle.reader.shuffle(reader=paddle.dataset.cifar.train10(), buf_size=train_size),\n",
    "        batch_size=train_batch)\n",
    "\n",
    "    valid_reader = paddle.batch(\n",
    "        reader=paddle.dataset.cifar.test10(),\n",
    "        batch_size=valid_batch)\n",
    "\n",
    "    # Model; SSRNet is not defined in this cell and must come from an earlier one\n",
    "    model = SSRNet()\n",
    "\n",
    "    # Learning-rate schedule: linear warmup from start_lr to based_lr, then cosine decay\n",
    "    consine_lr = fluid.layers.cosine_decay(based_lr, epoch_iters, epoch_num)\n",
    "    decayed_lr = fluid.layers.linear_lr_warmup(consine_lr, warmup_iter, start_lr, based_lr)\n",
    "\n",
    "    optimizer = fluid.optimizer.Momentum(\n",
    "        learning_rate=decayed_lr,                           # warmup + cosine schedule\n",
    "        momentum=momentum,                                  # momentum coefficient\n",
    "        regularization=fluid.regularizer.L2Decay(l2_decay), # L2 weight decay\n",
    "        parameter_list=model.parameters())\n",
    "\n",
    "    # Either resume from the checkpoint or start a fresh log\n",
    "    if checkpoint:\n",
    "        model_dict, optimizer_dict = fluid.load_dygraph(model_path) # load checkpoint\n",
    "        model.set_dict(model_dict)                                  # restore weights\n",
    "        optimizer.set_dict(optimizer_dict)                          # restore optimizer state\n",
    "    elif os.path.exists(result_txt):\n",
    "        os.remove(result_txt) # discard the log of a previous run\n",
    "\n",
    "    # Metrics shown in the reports\n",
    "    avg_train_loss = 0 # loss of the most recent training batch at each report\n",
    "    avg_valid_loss = 0 # sample-weighted mean validation loss of the last epoch\n",
    "    avg_valid_accu = 0 # sample-weighted mean validation accuracy of the last epoch\n",
    "\n",
    "    iterator = 1                                # global iteration counter\n",
    "    train_prompt = \"Train loss\"                 # training curve label\n",
    "    valid_prompt = \"Valid loss\"                 # validation curve label\n",
    "    ploter = Ploter(train_prompt, valid_prompt) # live loss plot\n",
    "\n",
    "    best_epoch = 0           # epoch with the lowest validation loss\n",
    "    best_accu = 0            # validation accuracy at best_epoch\n",
    "    best_loss = float('inf') # lowest validation loss seen so far\n",
    "    train_time = time.time() # start time; turned into elapsed seconds after training\n",
    "\n",
    "    for epoch_id in range(epoch_num):\n",
    "        # ---- train one epoch ----\n",
    "        model.train()\n",
    "        for batch_id, train_data in enumerate(train_reader()):\n",
    "            # Images: reshaped to (N, 3, 32, 32) float32, augmented, converted to tensors\n",
    "            image_data = np.array([x[0] for x in train_data]).reshape((-1, 3, 32, 32)).astype(np.float32)\n",
    "            image_data = train_augment(image_data)\n",
    "            image = fluid.dygraph.to_variable(image_data)\n",
    "\n",
    "            # Labels: one-hot with label smoothing, used as soft targets\n",
    "            label_data = np.array([x[1] for x in train_data]).astype(np.int64)\n",
    "            label = fluid.dygraph.to_variable(label_data)\n",
    "            label = fluid.layers.label_smooth(label=fluid.one_hot(label, class_num), epsilon=epsilon)\n",
    "            label.stop_gradient = True\n",
    "\n",
    "            # Forward pass and soft-label cross-entropy loss\n",
    "            infer = model(image)\n",
    "            loss = fluid.layers.cross_entropy(infer, label, soft_label=True)\n",
    "            train_loss = fluid.layers.mean(loss)\n",
    "\n",
    "            # Backward pass and parameter update\n",
    "            train_loss.backward()\n",
    "            optimizer.minimize(train_loss)\n",
    "            model.clear_gradients()\n",
    "\n",
    "            # Periodic report: plot, print and append to the log file\n",
    "            if iterator % displays == 0:\n",
    "                avg_train_loss = train_loss.numpy()[0]\n",
    "                ploter.append(train_prompt, iterator, avg_train_loss)\n",
    "                ploter.plot()\n",
    "\n",
    "                print(\"iteration: {:6d}, epoch: {:3d}, train loss: {:.6f}, valid loss: {:.6f}, valid accuracy: {:.2%}\".format(\n",
    "                    iterator, epoch_id+1, avg_train_loss, avg_valid_loss, avg_valid_accu))\n",
    "\n",
    "                with open(result_txt, 'a') as file:\n",
    "                    file.write(\"iteration: {:6d}, epoch: {:3d}, train loss: {:.6f}, valid loss: {:.6f}, valid accuracy: {:.2%}\\n\".format(\n",
    "                        iterator, epoch_id+1, avg_train_loss, avg_valid_loss, avg_valid_accu))\n",
    "\n",
    "            iterator += 1\n",
    "\n",
    "        # ---- validate ----\n",
    "        # Accumulate sample-weighted sums: the final batch may be smaller than valid_batch,\n",
    "        # so a plain mean of per-batch values would over-weight it.\n",
    "        valid_loss_sum = 0.0 # sum of per-sample validation losses\n",
    "        valid_accu_sum = 0.0 # accuracy weighted by batch size (i.e. correct-sample count)\n",
    "        valid_count = 0      # number of validation samples seen\n",
    "\n",
    "        model.eval()\n",
    "        with fluid.dygraph.no_grad(): # inference only: do not record the autograd graph\n",
    "            for batch_id, valid_data in enumerate(valid_reader()):\n",
    "                image_data = np.array([x[0] for x in valid_data]).reshape((-1, 3, 32, 32)).astype(np.float32)\n",
    "                image_data = valid_augment(image_data)\n",
    "                image = fluid.dygraph.to_variable(image_data)\n",
    "\n",
    "                label_data = np.array([x[1] for x in valid_data]).reshape((-1, 1)).astype(np.int64)\n",
    "                label = fluid.dygraph.to_variable(label_data)\n",
    "                label.stop_gradient = True\n",
    "\n",
    "                infer = model(image)\n",
    "                batch_size = label_data.shape[0]\n",
    "\n",
    "                valid_accu = fluid.layers.accuracy(infer, label)\n",
    "                valid_accu_sum += float(valid_accu.numpy()[0]) * batch_size\n",
    "\n",
    "                valid_loss = fluid.layers.mean(fluid.layers.cross_entropy(infer, label))\n",
    "                valid_loss_sum += float(valid_loss.numpy()[0]) * batch_size\n",
    "\n",
    "                valid_count += batch_size\n",
    "\n",
    "        avg_valid_accu = valid_accu_sum / max(valid_count, 1)\n",
    "        avg_valid_loss = valid_loss_sum / max(valid_count, 1)\n",
    "        ploter.append(valid_prompt, iterator, avg_valid_loss)\n",
    "\n",
    "        # Save the latest checkpoint; both state dicts share model_path, as read back by fluid.load_dygraph above\n",
    "        fluid.save_dygraph(model.state_dict(), model_path)\n",
    "        fluid.save_dygraph(optimizer.state_dict(), model_path)\n",
    "\n",
    "        # Keep a separate copy of the weights with the lowest validation loss\n",
    "        if avg_valid_loss < best_loss:\n",
    "            fluid.save_dygraph(model.state_dict(), model_path + '-best')\n",
    "            best_epoch = epoch_id + 1\n",
    "            best_accu = avg_valid_accu\n",
    "            best_loss = avg_valid_loss\n",
    "\n",
    "    # Final summary: print and append to the log file\n",
    "    train_time = time.time() - train_time\n",
    "    print('complete - train time: {:.0f}s, best epoch: {:3d}, best loss: {:.6f}, best accuracy: {:.2%}'.format(\n",
    "        train_time, best_epoch, best_loss, best_accu))\n",
    "\n",
    "    with open(result_txt, 'a') as file:\n",
    "        file.write('complete - train time: {:.0f}s, best epoch: {:3d}, best loss: {:.6f}, best accuracy: {:.2%}\\n'.format(\n",
    "            train_time, best_epoch, best_loss, best_accu))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "### 模型预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "infer time: 0.009923s, infer value: horse\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMgAAADFCAYAAAARxr1AAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAHC1JREFUeJztnXuMXHd1x79nZu689732erO24zixnZg8XGEgFNTybAOqFKhaBH9U+SMCKoFUBP9EVGqp1EpUKqD+UVEFNSKVKIEWUFIaKCGiStNAHiTEeZgkdmJnba937X3NzO687+kfMw47+z17Pdldr3fN+UjW7h7fufd378yZe89bVBWO49jELvcCHGcz4wriOBG4gjhOBK4gjhOBK4jjROAK4jgRuII4TgSuII4TwZoURERuE5GXROSYiNy1XotynM2CrDaSLiJxAC8D+CCAUwCeBPAJVX1xpdfEg4QGqaBDlslkaLtYTEhWqVRIptowjxOG/Pp4LE6yIJkgWT7XY+2RJLXFKskaTd5OY/b1jRtfTf2D20iWyeR4NSHvs1xeNGQLJEun+XpjhY9Ao9kkWRAEJEum0vYO6DB8IDWurSrLwpBlANA0rnmjXiNZPN75/k9PnkepUOQPyjL4E9I9bwdwTFVfBQARuQ/A7QBWVJAgFWDnjXs7ZDfd9BbaLpNJkuyV40dJVq/NmseplPj1OeODv2v3MMnecevvkUxCvuCnnjlOsvOzJZI1svxaAMjz5x5//PE/J9lbbnonycqLvM8jzz9FsueeY9nBA3y9LYUDgNm5Asm2je4k2dV795FMhL8B6k3+UqmDv/iqTVbshcWyucbSLMunz46TrL833/H3333+S+b+lrOWR6wxAEtXcqot60BEPiUiT4nIU806fyM5zmbmkhvpqnq3qh5W1cPxgB9zHGczs5ZHrNMAdi35e2dbtiIahqhXOm+JszPnaLvkju0kGx1h2bmpunmcuXN8i85m2V5ZWJwh2anTL5NsbIRujEinUyRTnSfZ4NCAuca3HtrPxxkdJVnTeJ6ePc/X7OWjL5FsboYfkRYW+HGqt99eY98wPwdWQ76O52fPkyyMZ0kmYtkLbDtVq/yoWq7aj6qLNd5nkOZ1yzL756LGR5u13EGeBLBPRK4RkSSAjwN4YA37c5xNx6rvIKraEJHPAvhvAHEA96jqC+u2MsfZBKzlEQuq+iCAB9dpLY6z6fBIuuNEsKY7yJslnohjaLivQ1avsh87YQTXioU5klWMeAAADA8MsWw4T7KbbmGfPowg5dnJUyTLBmykZ7MchGs07DWm0xyrSSX42DHl2EEixobywQPXkWz3VVeRLNfDwcggawVHgcoiOx1qjSLJ5ubZ2TGzwO+X5XAIhM9lscD7q4Z2iKBoxA8b81Mkyy6Lt1SNwLOF30EcJwJXEMeJwBXEcSJwBXGcCDbUSE+nUrju2j0dsiDJ6SepFMvm56dJ1qzb2bxqJBcOD/WT7PoDe0k2PcsJkM8+8wqvMdhBsr7+XpIVY2ysAnZUOQZedyrBstERjhTnMntI9sJznBXQk+HM2zDJDgcAiNVYnkuwsZxK8vdsocTnVyxwhoOEbCxXipwBsNiwsyZmjIzjsMjv4WK9MzrfWOGzsxy/gzhOBK4gjhOBK4jjROAK4jgRbKyRnk7hwP5rO2S5HEef6w023IoFjo42a3Z0dWqCX79jhKsHBw2jenqanQGJBDsNjII5hIYhGQ/s76Aw5G1j4KyCVJIj6VYFYCPNsljIr80aRvZ8jVPOAaA4z8ZyT9pwLoR8jqkGy3LCH7eZeY7MWxF3GE4NAIgpX8dymdPlK4VOw90qJzb339VWjvNbiiuI40TgCuI4EbiCOE4EriCOE8GavFgicgJAEUATQENVD19ke6SWNWtLpXgJVg+rgzccIFk6sJd/8jinGoyNctOHIMGvtxrMZY30jHiTv1uGhvpI
FvbY3pJUiutBmk32vtSqRhO8uLFuw0OUCtjzk0/yeqbO83EB4PzZsyyrsmercJZ7dZQWjfUY1/HE+DGS7drDdSw92zlVCADiFfZ4FYxGIPW5ztqWZqO7VJP1cPO+V1W5rYXjXAH4I5bjRLBWBVEAPxGRX4rIp6wNlnZWXCjZASnH2ays9RHr3ap6WkS2A3hIRH6tqo8s3UBV7wZwNwCM7R71mdPOlmKtbX9Ot39OicgP0Gpo/chK28cTCfQNdqZ81BucXlGscPrBVbvYcOtL2c0GqsUTJBscYCMvFuMbaCbLxvPAIDd86AlZtnPn1SQLM7YxmO8zajCMdn+lBW6cEDcaSzRqnLJjpchUjQ6FzzzxS3ONP3/61ySrL/B7M33mdX5xjGtWtg9xus/4aTbSc3l+X7btto30nGFsx4x6oHBZR8hupxqs+hFLRHIi0nPhdwB/AOD51e7PcTYja7mDjAD4gYhc2M+/qeqP12VVjrNJWEvr0VcB3LKOa3GcTYe7eR0ngg2tB4kl4sgNDXbIzkycoO3OzXPccWgHt+iPiz36K2M0Ieg3WvxncmyQDxl1I8UFrqvIGEb60DAfI8jbl1gDNi6TaXY61OtsFFtdAWtlNtwLJa6VeOKJX5DsJw/afpXZOTZ2G4aRX2uywRsPeLuRPv4+Tjf5PTh/hiP4e27gJhkA0GvU2+SM0XoVGvV2iY10x/ltwBXEcSJwBXGcCFxBHCeCDTXSIYBkOtOtYxk2LhNGNHtukY1VNYxDAIgbowV6BzgSn+/laO9rpzhV/twkp4hnhQ3gA9ewkT02yvMNAWCxyeeTSPB6YtY45TpHyE+f4pEBD/2EDfJnn+Xo+My0nSOXTnH6ftMYV2AFpStlw7mwwGn1mRjPXS+e59eePn7GXOPYGI9zyKZ4PuJcrHOf1phqC7+DOE4EriCOE4EriONE4AriOBFsbCQ9rgjyncbtjt2DtF3/do5Sq5XWXLaN9FyMU6Mnpzmye+Qod2t87DFOSD53miP7KWXjcnGGL+d7/9AeLbDnwC6SBXHeJ2KcLTB+jlPgf/xfPyfZIz/jNPZa0xgjYJwLANTrvG29wVkFVaM8oVbj2vXZWXYu5AM+v1jduLbG+wcAsxk+djLJRnoq0/mZEqPUwcLvII4TgSuI40TgCuI4EbiCOE4EFzXSReQeAH8EYEpVb2zLBgF8B8AeACcAfExVOQS9jFAbqDQ6N0ulOXocpA0jK8HGbio0irgBLM5wJP34S9xM7MXnx0k2foKN0ECHSDY7x4byo2d/xdsVuOYeAMb2cvp2b5bPJybsiDh69FWSPf5/L5Bs0QiQJ1NsFDdWyEioGvMDq1WeM1iu8BzGWIJfO7c4SbIwxSMotvezAyMd8GcCAAaG2MlTNcZQ9IWd6e4Jo2mgRTd3kG8CuG2Z7C4AD6vqPgAPt/92nCuOiypIu43P8kSf2wHc2/79XgAfWed1Oc6mYLU2yIiqTrR/P4tWAweTpY3jinN8e3aczcyajXRtNRhasX5RVe9W1cOqerinn+0Nx9nMrDaSPikio6o6ISKjADgkbdBo1HFuqrMT+LZhTkNv1NlYLRr914KQI7MAcPQZnjOICjd1G+jlVPT5XjbSk0Y0e7ZiNG9juxTT5+xU8mdffJRkxblTJDO/wYzIdyLOqfZ9vbzuUDl1X41GawBQq7ODoW7Jmvxk8NZ33UiyXaOjJHv28SMkKzU5Cl+Hvcbd+/k9jBkTA0bLne/rI//5U3N/tK+utmIeAHBH+/c7ANy/yv04zqbmogoiIt8G8HMAB0TklIjcCeDLAD4oIq8A+ED7b8e54rjoI5aqfmKF/3r/Oq/FcTYdHkl3nAg2NN292WiiNNtpgMUavIRGg0eHzc5w5Lo4YxtuR37BUfMbr+URbHWjIdz0JDct6+/lZnK5PHdTLysb5PkcR3oBIDHN9dnlEn9fNRt8jrkcHzubZsM9k+GMgmKRzzkmfL0BIJ1mI79mOFD6hrk84a3v
exvJ9h+4hmQDu7jZ3q+feY1khYadqCF5jppXxYjilzqj/U3tbgSb30EcJwJXEMeJwBXEcSJwBXGcCFxBHCeCjfVi1ZuYmez0JhSmOU2hYTRomJpkz1S5YKeaTIyzx6t6/imShVVOISsWjayZkNMzrtrOnq258+zFajT4tQDQ28OvX+znGpGFItdaiPG2NZrs7Yon2OMUGKMhEgl7jERvH3vB4lPcJGH3wb0ku/7wQZJl8vx+Hfr9m0m2fYzTj8oLdqJrWYzmEMZcx3OLndexERrNKwz8DuI4EbiCOE4EriCOE4EriONEsKFGepAMMDq2s0MWj/MSyiVOw0iBDbyzdTbQAKBhGHQnTj7N60kY8+16uYlAIsWGdt0YX1AqcR1KKrXfXOPeEa5PqRsNERp1TgOxjO+E0RwxNL7+8oOcFjKQ5RECAJBO8XGu2suG++9+6ADJ9ozy6ISGsmGsed5fj5XGU7WbX4QBf1Z6E7zPhnTuM2FcQwu/gzhOBK4gjhOBK4jjROAK4jgRrLaz4pcAfBLAhfD2F1X1wYvtK5VO49r9+zr3H2PjOy28rJRRsjD+GnfqA4AnHzLqCco8w8/q+t8/xMZlT54N26mJCZLVa5wB0GzatRbWPnft3E0y6xtMwcauGDUdxQWOwvf0sAE8Mmwb6YFxfXbdzI0Xxq7jcRNiRPaTcX6v40k+iHXNYkn7OlaMcQ75OGdIxJcdO2E4hyxW21kRAL6mqofa/y6qHI6zFVltZ0XH+a1gLTbIZ0XkiIjcIyJcN9lmaWfFwqx3VnS2FqtVkK8DuBbAIQATAL6y0oZLOyv2DnhnRWdrsapIuqq+YR2LyDcA/LCb1zWbTRSLnZHPhGGklSps4KWNsPDcrB1Jr1aM2Xp1Y8ahFUnPGI0KFrgJwNkz3NxBjZl+J4+fNNdYNjoz5rOcij62g9PiFRzZr1Q51f7ka6+QLJviYyyk7Sfo/DB/oSX7uIPjbJmPrWU2qhPGDMbAMNyThndAQnuOYgzGWITQyD6QZe9/d4H01d1B2u1GL/BRADz50nGuALpx834bwHsADIvIKQB/DeA9InIIrabVJwB8+hKu0XEuG6vtrPgvl2AtjrPp8Ei640SwoenuoSrK1U6DN6yyAZxPGnP0jLrwhUU7BbpaZWO5abzeqhefL7DBGoYcmZ2b5/mGFcNYDWPcyRAAEkZkuGGk2sfFqivnt21+jjsPnp/kaH82zt+J9Yp9Hfvi7L3fBc4ACMt8fZqGoZywMgCMqHnGSFdvjaFhmgHLJcay2LKMDVFPd3ecNeMK4jgRuII4TgSuII4TwYYa6QAgyyKnWaPFfjpgIy2jrMuJuDGLEECzueJM0Q7qxmiB6WmOkI/tYsP0lrdxczOJs+F3Zpwb3gHAyVNPkiyTYaO4boyCSKf4mtWNKH61xpkGxTmOpDcCY7gigL4UR6nDODs2qg2+3g3DSA/rfL1jhpFeEz6X8qK9Rg34mieTnDWRS3duZzleLPwO4jgRuII4TgSuII4TgSuI40SwoUa6GpH0uhqpyYb9FMTZcIdRzw4AYjSZswKxahiSuTwbwB/90w+QbO8NbKTHU7zGI0+9YK7xpz96lGTnjbmFzTobxc2YUfsOI3KdYdlChQ337UOcUg8AB27m5na9/Wy4F6rcRM+aAVirsaFdN0obrBT4utHxHwDCqpXaztkL1WWN4pordN1fjt9BHCcCVxDHicAVxHEicAVxnAi6qSjcBeBfAYygVUF4t6r+o4gMAvgOgD1oVRV+TFXtae9tFECITmu50WTjS40UbyvwWavahlutyoafsUuI8fWwbz+PE9t/kLuXZ4b50tWUDb/D736Hucb9+99CsvEzZ0h25ixH9qF8bA35BH90/0Mkm3iZR8wNjnCaPQBs38kN4RZKPN4sbPJ5xwI2nkWMN9H4BBaMLvdN4xgAkBI26GOGR6Za7vysrGckvQHgC6p6EMCtAD4jIgcB3AXgYVXdB+Dh9t+Oc0XRTeO4
CVV9uv17EcBRAGMAbgdwb3uzewF85FIt0nEuF2/KBhGRPQB+B8DjAEZU9ULJ2lm0HsGs17zROG5hnivuHGcz07WCiEgewPcAfE5VO6JN2qqHNB/qljaOy/UZPYwcZxPTlYKISICWcnxLVb/fFk9e6I/V/mkMGHecrU03XixBq83PUVX96pL/egDAHQC+3P55fxf7QrC8a55RQ9E00k+MjAvMTdu9fms1ozmA0azAcm3tNobYZ1PcYXBhntNCakZjiOQKLfwG8lz70bOfuxaODA2SzPLAGGUjeOx/HiPZuOE17OmxuxbmjIYIzRKncRiTF8y0m5rR/TFjdVFM8zWrGesGgIRx4mJ8fpZ7T1d44OH9d7HNuwD8GYDnRORXbdkX0VKM74rInQBOAvhYV0d0nC1EN43jHsXKnUzfv77LcZzNhUfSHScCVxDHiWBjmzaoUlqC1cp/ocx6qyHXaVTmbUMrtOYCxvgpMWs0Jbj6Kp7BF6vwGmWBj50MOe0hZRW3AAgCdiRk49wcIhXjc6k0uKaj1GDjOWk4NpJxbtpwwzVj5hr3DhqpJkleY8nocDlvzGusFK2RCMZ7ZaSkiJUXBHPSAerG3MJmo3ON4QqdGpfjdxDHicAVxHEicAVxnAhcQRwngg1v2tCsdBqTGhhjCWrGDL4iyyZP210L1Yg0G+US6O/rI9mQIZs7a3RwDPnSpRNs9DcadkfAWNwwqo13I6izAVwpcwaBGk0S1Oh4GBhNKQa28zkDQDrNC4oLG/mG/wN9yhHyoQbLJie43qVhOBwq4QpNGwzj3TL8q6Vl74OVemDgdxDHicAVxHEicAVxnAhcQRwngg010mMKpJbZ2rGAo89loxnDzAQb5Oen7BIUK7PS6heQy7JRvVBkA/i1l06QLBnj1w71cYfChDEvDwDCpJFKXj1PMjWM/KqVX57hlPy48NubyhkNH8RuiLBY4JR+VTbSYaXQxzMkS6d5jWEPOwikOMfHLdtGdWG58Q0gnTQyMYqdn7NY02cUOs6acQVxnAhcQRwnAlcQx4lgLZ0VvwTgkwAuWM9fVNUHo/YVgyAbdkZTa2WOmqLE6cqL57ijX21hhSi1YaaLUX8exI3uiGXepxjp94UK11c3Cmzs9vWyYQoAkjVq8RdneJ81Pk49yfXe8STXs4sR7e/NGPMf67YBXJ/itHoFG+lqpOTPhvwe1ow5ijEjWp8O2cCPG+cMAIkGX8d4jR0g8WrnumPanZHejRfrQmfFp0WkB8AvReRCT8uvqeo/dHUkx9mCdFOTPgFgov17UUQudFZ0nCuetXRWBIDPisgREblHRLiPDTo7K5YK3lnR2VqspbPi1wFcC+AQWneYr1ivW9pZMd/rnRWdrUVXkXSrs6KqTi75/28A+OFFd9QEYsuCs0GCI+m94LrnoMLR1WqJU8EBIB5jvQ+SbFymAjb80sZ2PSlOES812blQmOdZfcV5eyJEIsavv26Uo8qZPF+LQoGj/fNnOKtgbpaN7IEcG8ADwjIAyC4YWQ7GTMHQmIVYXp4yAaBujJRU4/2PG3XziRWy061ov3FpkdTO91pW7GS1bF8X22ClzooX2o62+SiA57s6ouNsIdbSWfETInIILdfvCQCfviQrdJzLyFo6K0bGPBznSsAj6Y4TwcamuyOGnHYanaEx1y+RZG/X7iFeal/uFfM4QYIN46TRRTxhDKxfWGCjr1nj6HqQYMO2Z4Cj2aFRXw0AiwuGgyHPRnoixw6CsMIW6+lxNtKnZ9mxMTy8m2SBER0HgGST5ZUqn0+xzBkAlV6+tknj+sTT/L7US+wIaNTt62hlJKiRYLE84N5d2zi/gzhOJK4gjhOBK4jjROAK4jgRbLiRntJOAzxmdD/XGhtuuYANvH3X7jOP0zQisePjp0lWKnDk+9grx3h/Rl14Is7G8+g2zuHsMYxsANi2nUer1ZJs+JeEjV1N83YNI50/leXtakaWQV3sVHKNG6n6Rnc7aXBkv17kKH4s
4NcGxrp7jcyFecNRAgDxPl57o2o0I5xflmpvHNfC7yCOE4EriONE4AriOBG4gjhOBK4gjhPBxs4oBKDxi8f800b6wZ69O0nWv2ObeYx3LrLH49H/fYxkrx1nj1WtZgyhr3KzgWrIqQ+v13i7dMZO45DcdSTLp4zzMboEJvvZO9U7xJ6fwSE+9sIie5ymZ20P0eCeHSQLA15PptpLsqDGXqJKhWVVNeYJZvlDsWCMdwCAsuGMSuc5VWm5E1SsmQ0GfgdxnAhcQRwnAlcQx4mgm5LbtIg8ISLPisgLIvI3bfk1IvK4iBwTke+IrBCOdZwtTDdGehXA+1S11G7e8KiI/AjA59FqHHefiPwzgDvR6nSyIqqKeqPT2LLqNNIpNi6DGG+X7eGGBgCwQ1nvt/fdRrLx1zn9pDfPKS1Tr79Gsvl5rrUI8mw8V0N7tEDd8E5UjHmEGuO3qFTiFJmEkbJz4/XXkKxvgJ0dg/08tgEApit8nB1Xs+FeOM0Oi8lZHuVQqLOhXTMcMlpnAzrM2N/lqQR/VqZnuT7l9MlTHX9XalxzYnHRO4i2uFBFFLT/KYD3AfiPtvxeAB/p6oiOs4XoygYRkXi7YcMUgIcAHAcwp78ZrXoKK3RbXNo4rlBiF6PjbGa6UhBVbarqIQA7AbwdwPXdHmBp47jevN3I2XE2K2/Ki6WqcwB+BuCdAPpF3pjxtRMAP9A7zhanm/EH2wDUVXVORDIAPgjg79FSlD8BcB+AOwDcf9GjqSBe7zxkLMFLiBk1CzWjaD9pl1pAjOHyO4c4Sj02OEKywjyPWehL8P5mZqdJVrIaCxj1EwAQGl9NWmFDu2p0MkwbcwL7R7gZw/49fKOvG0a/Vm2DNd/LDot8H1/06gzXrFTBEfIzU5MkKxmG++AY18oEeTbmASAEX/OUUQ/UM9DZOjpuNOyw6MaLNQrgXhGJo3XH+a6q/lBEXgRwn4j8LYBn0Oq+6DhXFN00jjuCVkf35fJX0bJHHOeKxSPpjhOBK4jjRCCq3faYW4eDiZwDcBLAMAAOtW5N/Fw2Jxc7l6tV1a6XWMKGKsgbBxV5SlUPb/iBLwF+LpuT9ToXf8RynAhcQRwngsulIHdfpuNeCvxcNifrci6XxQZxnK2CP2I5TgSuII4TwYYriIjcJiIvtUt179ro468FEblHRKZE5PklskEReUhEXmn/HIjax2ZBRHaJyM9E5MV2KfVftOVb7nwuZVn4hipIO+HxnwB8CMBBtCblHtzINayRbwJYXrt7F4CHVXUfgIfbf28FGgC+oKoHAdwK4DPt92Irns+FsvBbABwCcJuI3IpW1vnXVPU6ALNolYW/KTb6DvJ2AMdU9VVVraGVKn/7Bq9h1ajqIwCWFzzfjlbJMbCFSo9VdUJVn27/XgRwFK2q0C13PpeyLHyjFWQMwPiSv1cs1d1CjKjqRPv3swC4yGSTIyJ70MrYfhxb9HzWUhYehRvp64i2fOZbym8uInkA3wPwOVXtmHqzlc5nLWXhUWy0gpwGsGvJ31dCqe6kiIwCQPsnz2PepLTbOH0PwLdU9ftt8ZY9H2D9y8I3WkGeBLCv7V1IAvg4gAc2eA3rzQNolRwD3ZYebwJERNCqAj2qql9d8l9b7nxEZJuI9Ld/v1AWfhS/KQsHVnsuqrqh/wB8GMDLaD0j/uVGH3+Na/82gAkAdbSeae8EMISWt+cVAD8FMHi519nlubwbrcenIwB+1f734a14PgBuRqvs+wiA5wH8VVu+F8ATAI4B+HcAqTe7b081cZwI3Eh3nAhcQRwnAlcQx4nAFcRxInAFcZwIXEEcJwJXEMeJ4P8BDOlRG0/6AGEAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 216x216 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Standard library\n",
    "import time\n",
    "\n",
    "# Third-party\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import paddle.fluid as fluid\n",
    "from PIL import Image\n",
    "\n",
    "# Image to classify, and the checkpoint path passed to fluid.load_dygraph in this cell\n",
    "# (presumably the prefix of the saved .pdparams file -- confirm against the training cell)\n",
    "image_path = './work/out/img.png'\n",
    "model_path = './work/out/ssrnet-best'\n",
    "\n",
    "# Load an image and convert it to the network's input format\n",
    "def load_image(image_path):\n",
    "    \"\"\"\n",
    "    Read an image file and convert it into the model's input format.\n",
    "\n",
    "    Steps: force 3-channel RGB, resize to 32x32, scale to [0, 1],\n",
    "    normalize with the per-channel CIFAR mean/std, reorder HWC -> CHW,\n",
    "    and prepend a batch axis.\n",
    "\n",
    "    Args:\n",
    "        image_path - path of the input image file\n",
    "\n",
    "    Returns:\n",
    "        image - float32 ndarray of shape (1, 3, 32, 32)\n",
    "    \"\"\"\n",
    "    # Read the image. convert('RGB') guarantees exactly 3 channels: an RGBA,\n",
    "    # grayscale or palette PNG would otherwise give an array that cannot be\n",
    "    # broadcast against the 3-element mean/std below (alpha is dropped, not\n",
    "    # composited). The `with` block closes the file handle once decoded.\n",
    "    with Image.open(image_path) as src:\n",
    "        rgb = src.convert('RGB')\n",
    "\n",
    "    # Resize. LANCZOS is the filter the old ANTIALIAS alias pointed to;\n",
    "    # ANTIALIAS is deprecated and was removed in Pillow 10.\n",
    "    rgb = rgb.resize((32, 32), Image.LANCZOS)\n",
    "    image = np.array(rgb, dtype=np.float32)  # HWC, values in [0, 255]\n",
    "\n",
    "    # Normalize with the CIFAR per-channel mean and standard deviation\n",
    "    mean = np.array([0.4914, 0.4822, 0.4465]).reshape((1, 1, -1))\n",
    "    stdv = np.array([0.2471, 0.2435, 0.2616]).reshape((1, 1, -1))\n",
    "    image = (image / 255.0 - mean) / stdv\n",
    "\n",
    "    # HWC -> CHW; the mean/std arithmetic promotes to float64, so cast back to float32\n",
    "    image = image.transpose((2, 0, 1)).astype(np.float32)\n",
    "\n",
    "    # Add the batch dimension\n",
    "    image = np.expand_dims(image, axis=0)\n",
    "\n",
    "    return image\n",
    "\n",
    "# Run inference on the image\n",
    "with fluid.dygraph.guard():\n",
    "    # Read the image and wrap it as a dygraph variable\n",
    "    image = load_image(image_path)\n",
    "    image = fluid.dygraph.to_variable(image)\n",
    "\n",
    "    # Build the network and load the trained weights.\n",
    "    # NOTE(review): SSRNet is not defined in this cell; an earlier cell must have run first.\n",
    "    model = SSRNet()\n",
    "    model_dict, _ = fluid.load_dygraph(model_path)  # second return value is unused\n",
    "    model.set_dict(model_dict)\n",
    "    model.eval()  # evaluation mode\n",
    "\n",
    "    # Forward pass, timed with perf_counter (monotonic and high-resolution;\n",
    "    # time.time() can jump with clock adjustments and is coarse on some platforms)\n",
    "    start_time = time.perf_counter()\n",
    "    infer = model(image)\n",
    "    infer_time = time.perf_counter() - start_time\n",
    "\n",
    "    # Report the predicted class: index of the largest output in the label-name list\n",
    "    vlist = [\"airplane\", \"automobile\", \"bird\", \"cat\", \"deer\", \"dog\", \"frog\", \"horse\", \"ship\", \"truck\"]\n",
    "    print('infer time: {:f}s, infer value: {}'.format(infer_time, vlist[np.argmax(infer.numpy())]))\n",
    "\n",
    "    # Show the original image; `with` closes the file once it has been drawn\n",
    "    with Image.open(image_path) as image:\n",
    "        plt.figure(figsize=(3, 3))\n",
    "        plt.imshow(image)\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PaddlePaddle 1.8.4 (Python 3.5)",
   "language": "python",
   "name": "py35-paddle1.2.0"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
